
Gluon DataLoader: avoid recursionlimit error #12622

Merged: 6 commits, Sep 24, 2018
7 changes: 6 additions & 1 deletion python/mxnet/gluon/data/dataloader.py
@@ -175,7 +175,12 @@ def _recursive_fork_recordio(obj, depth, max_depth=1000):
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
     # re-fork a new recordio handler in new process if applicable
-    _recursive_fork_recordio(dataset, 0, 1000)
+    # for a dataset with transform function, the depth of MXRecordIO is 1
+    # for a lazy transformer, the depth is 2
+    # for a user defined transformer, the depth is unknown, try a reasonable depth
+    limit = sys.getrecursionlimit()
+    max_recursion_depth = min(limit - 5, max(10, limit // 2))
+    _recursive_fork_recordio(dataset, 0, max_recursion_depth)

     while True:
         idx, samples = key_queue.get()
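To see how the new cap behaves, here is a minimal sketch (plain Python, not part of the PR; the helper name is illustrative) that evaluates the same expression used in worker_loop at a few hypothetical recursion limits:

import sys

def proposed_max_depth(limit):
    # same expression as the new worker_loop code
    return min(limit - 5, max(10, limit // 2))

for limit in (1000, 500, 100, 20):
    print(limit, proposed_max_depth(limit))
# 1000 -> 500, 500 -> 250, 100 -> 50, 20 -> 10

With the default limit of 1000 the fork depth stays at 500, while an unusually low limit still leaves a small safety margin instead of hitting RecursionError.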
32 changes: 32 additions & 0 deletions python/mxnet/recordio.py
@@ -83,6 +83,32 @@ def open(self):
     def __del__(self):
         self.close()

+    def __getstate__(self):
+        """Override pickling behavior."""
+        # pickling pointer is not allowed
+        is_open = self.is_open
+        self.close()
+        d = dict(self.__dict__)
+        d['is_open'] = is_open
+        uri = self.uri.value
+        try:
+            uri = uri.decode('utf-8')
+        except AttributeError:
+            pass
+        del d['handle']
+        d['uri'] = uri
+        return d
+
+    def __setstate__(self, d):
+        """Restore from pickled."""
+        self.__dict__ = d
+        is_open = d['is_open']
+        self.is_open = False
+        self.handle = RecordIOHandle()
+        self.uri = c_str(self.uri)
+        if is_open:
+            self.open()
+
     def close(self):
         """Closes the record file."""
         if not self.is_open:
@@ -217,6 +243,12 @@ def close(self):
         super(MXIndexedRecordIO, self).close()
         self.fidx.close()

+    def __getstate__(self):
+        """Override pickling behavior."""
+        d = super(MXIndexedRecordIO, self).__getstate__()
+        d['fidx'] = None
+        return d
+
     def seek(self, idx):
         """Sets the current read pointer position.

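The change above lets MXRecordIO instances survive pickling by dropping the C handle in __getstate__ and recreating and reopening it in __setstate__. A standalone sketch of the same pattern, using an ordinary file object in place of the C handle (class and file names here are illustrative, not from MXNet):

import os
import pickle
import tempfile

class ReopenableReader(object):
    """Toy reader: drop the OS handle when pickling, reopen it when unpickling."""
    def __init__(self, path):
        self.path = path
        self.is_open = True
        self.handle = open(path, 'rb')

    def __getstate__(self):
        d = dict(self.__dict__)
        d['handle'] = None            # open file objects cannot be pickled
        return d

    def __setstate__(self, d):
        self.__dict__ = d
        if self.is_open:              # re-acquire the handle on the receiving side
            self.handle = open(self.path, 'rb')

path = os.path.join(tempfile.mkdtemp(), 'toy.rec')
with open(path, 'wb') as f:
    f.write(b'record')
reader = pickle.loads(pickle.dumps(ReopenableReader(path)))
print(reader.handle.read())           # b'record'

This is what allows a dataset holding a record reader to be shipped to DataLoader worker processes without trying to serialize a raw pointer.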
49 changes: 31 additions & 18 deletions tests/python/unittest/test_gluon_data.py
@@ -73,26 +73,39 @@ def test_recordimage_dataset():
         assert x.shape[0] == 1 and x.shape[3] == 3
         assert y.asscalar() == i

+def _dataset_transform_fn(x, y):
+    """Named transform function since lambda function cannot be pickled."""
+    return x, y
+
 @with_seed()
 def test_recordimage_dataset_with_data_loader_multiworker():
-    # This test is pointless on Windows because Windows doesn't fork
-    if platform.system() != 'Windows':
-        recfile = prepare_record()
-        dataset = gluon.data.vision.ImageRecordDataset(recfile)
-        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
-
-        for i, (x, y) in enumerate(loader):
-            assert x.shape[0] == 1 and x.shape[3] == 3
-            assert y.asscalar() == i
-
-        # with transform
-        fn = lambda x, y : (x, y)
-        dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
-        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
-
-        for i, (x, y) in enumerate(loader):
-            assert x.shape[0] == 1 and x.shape[3] == 3
-            assert y.asscalar() == i
+    recfile = prepare_record()
+    dataset = gluon.data.vision.ImageRecordDataset(recfile)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+
+    # with transform
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+
+    # try limit recursion depth
+    import sys
+    old_limit = sys.getrecursionlimit()
+    sys.setrecursionlimit(500) # this should be smaller than any default value used in python
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+    sys.setrecursionlimit(old_limit)
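The switch from a lambda to the module-level _dataset_transform_fn matters because the transformed dataset is pickled when it is sent to worker processes, and lambdas cannot be pickled by the standard pickle module. A quick illustration (plain Python, not from the test file):

import pickle

def named_fn(x, y):
    return x, y

pickle.dumps(named_fn)                 # works: module-level functions pickle by reference

try:
    pickle.dumps(lambda x, y: (x, y))  # fails: a lambda has no importable name
except (pickle.PicklingError, AttributeError) as e:
    print('lambda cannot be pickled:', e)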
Contributor:

A little worried: what if the test case fails after setting the recursion limit to 100, but before resetting it back to old_limit?

Member Author:

A fail is a fail; no need to care whether the rest of the tests fail after that, right?

Contributor (@sandeep-krishnamurthy, Sep 21, 2018):

Yeah, you are right. I was thinking of it as regular user functionality, not as a unit test.

LGTM. Thanks.
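For reference, the pattern the reviewer is alluding to, restoring the limit even if an assertion fails, would be a try/finally around the lowered-limit portion of the test. This is only a sketch of that alternative, not part of the merged change:

import sys

old_limit = sys.getrecursionlimit()
sys.setrecursionlimit(500)
try:
    # run the DataLoader iteration and assertions here
    pass
finally:
    # restored even if an assertion above fails
    sys.setrecursionlimit(old_limit)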


@with_seed()
def test_sampler():
Expand Down