Fix lazy record io when used with dataloader and multi_worker > 0 (apache#12554)

* temporary solution for record file dataset with multiple workers

* fix cascaded datasets for the gluon DataLoader when multi_worker > 0 is used
zhreshold authored and lebeg committed Nov 5, 2018
1 parent 0a286a0 commit 435f0ce
Showing 3 changed files with 29 additions and 11 deletions.
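
For context, a minimal sketch of the scenario this commit fixes (the record file path and the identity transform are illustrative): a record-backed dataset is wrapped by .transform() and then read through a multi-worker DataLoader, so each forked worker inherits a C file handle from the parent process.

from mxnet import gluon

# 'val.rec' is an illustrative path to an existing record file.
dataset = gluon.data.vision.ImageRecordDataset('val.rec')
# transform() returns a wrapper dataset around the record-backed one.
dataset = dataset.transform(lambda x, y: (x, y))
# num_workers > 0 forks worker processes; before this fix the workers
# inherited a stale record file handle from the parent.
loader = gluon.data.DataLoader(dataset, batch_size=1, num_workers=2)
for x, y in loader:
    pass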
20 changes: 18 additions & 2 deletions python/mxnet/gluon/data/dataloader.py
@@ -36,6 +36,7 @@

from . import sampler as _sampler
from ... import nd, context
from ...recordio import MXRecordIO

if sys.platform == 'darwin' or sys.platform == 'win32':
    def rebuild_ndarray(*args):
@@ -158,10 +159,24 @@ def _as_in_context(data, ctx):
        return [_as_in_context(d, ctx) for d in data]
    return data

def _recursive_fork_recordio(obj, depth, max_depth=1000):
    """Recursively find instances of MXRecordIO and reset their file handles.
    This is required because MXRecordIO holds a C pointer to an opened file,
    which becomes invalid in the new process after fork.
    """
    if depth >= max_depth:
        return
    if isinstance(obj, MXRecordIO):
        obj.close()
        obj.open()  # re-obtain the file handle in the new process
    elif hasattr(obj, '__dict__'):
        for _, v in obj.__dict__.items():
            _recursive_fork_recordio(v, depth + 1, max_depth)

def worker_loop(dataset, key_queue, data_queue, batchify_fn):
    """Worker loop for multiprocessing DataLoader."""
    if hasattr(dataset, '_fork') and callable(dataset._fork):
        dataset._fork()
    # re-open the recordio handle in the new process if applicable
    _recursive_fork_recordio(dataset, 0, 1000)

    while True:
        idx, samples = key_queue.get()
        if idx is None:
@@ -181,6 +196,7 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False):
            batch = _as_in_context(batch, context.cpu())
        data_buffer[idx] = batch


class _MultiWorkerIter(object):
"""Interal multi-worker iterator for DataLoader."""
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
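
The essential fix is the close()/open() pair applied to every MXRecordIO instance found while walking the dataset's attributes. A minimal sketch of that step in isolation, assuming an existing record file 'data.rec' (illustrative path):

from mxnet.recordio import MXRecordIO

rec = MXRecordIO('data.rec', 'r')  # holds a C pointer to the opened file
# ... the DataLoader forks a worker; in the child process:
rec.close()
rec.open()         # fresh file handle that is valid in this process
item = rec.read()  # returns the next record as bytes

The recursion is needed because the DataLoader only sees the outermost dataset object, while the handle may sit arbitrarily deep inside wrappers; the max_depth guard keeps cyclic object graphs from recursing forever.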
8 changes: 0 additions & 8 deletions python/mxnet/gluon/data/dataset.py
@@ -94,11 +94,6 @@ def base_fn(x, *args):
            return fn(x)
        return self.transform(base_fn, lazy)

    def _fork(self):
        """Protective operations required when launching multiprocess workers."""
        # for non file descriptor related datasets, just skip
        pass


class SimpleDataset(Dataset):
"""Simple Dataset wrapper for lists and arrays.
@@ -180,9 +175,6 @@ class RecordFileDataset(Dataset):
    def __init__(self, filename):
        self.idx_file = os.path.splitext(filename)[0] + '.idx'
        self.filename = filename
        self._fork()

    def _fork(self):
        self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')

    def __getitem__(self, idx):
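
Dropping the per-class _fork hook is deliberate: transform() returns a wrapper dataset, so a worker calling _fork on the outer object only reached the base class's no-op and the inner record handle was never reset. A short sketch of the object shape that motivates the recursive walk instead ('val.rec' is an illustrative path):

from mxnet import gluon

inner = gluon.data.vision.ImageRecordDataset('val.rec')  # owns the MXRecordIO handle
outer = inner.transform(lambda x, y: (x, y))             # wrapper seen by the DataLoader
# _recursive_fork_recordio walks outer.__dict__ through the wrapper down to
# the MXRecordIO instance inside `inner` and re-opens it after the fork.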
12 changes: 11 additions & 1 deletion tests/python/unittest/test_gluon_data.py
@@ -65,7 +65,8 @@ def prepare_record():
@with_seed()
def test_recordimage_dataset():
    recfile = prepare_record()
    dataset = gluon.data.vision.ImageRecordDataset(recfile)
    fn = lambda x, y: (x, y)
    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
    loader = gluon.data.DataLoader(dataset, 1)

    for i, (x, y) in enumerate(loader):
@@ -84,6 +85,15 @@ def test_recordimage_dataset_with_data_loader_multiworker():
        assert x.shape[0] == 1 and x.shape[3] == 3
        assert y.asscalar() == i

    # with transform
    fn = lambda x, y: (x, y)
    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)

    for i, (x, y) in enumerate(loader):
        assert x.shape[0] == 1 and x.shape[3] == 3
        assert y.asscalar() == i

@with_seed()
def test_sampler():
    seq_sampler = gluon.data.SequentialSampler(10)
