improve dataloader signals and messages (apache#16114)
* improve dataloader signals and messages

* address comments

* fix spawn tests on windows
zhreshold authored and larroy committed Sep 28, 2019
1 parent 17e7c7f commit 4de39b3
Showing 2 changed files with 44 additions and 11 deletions.
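For context, here is a minimal usage sketch of the new `timeout` argument this change adds to `DataLoader`; the dataset and parameter values are illustrative, not taken from the commit:

```python
import numpy as np
import mxnet as mx

# Any Gluon dataset works; a small random array keeps the sketch self-contained.
dataset = mx.gluon.data.ArrayDataset(np.random.rand(64, 10), np.random.rand(64))

# `timeout` (in seconds) bounds how long each worker may take to return a batch;
# 120 is the new default introduced by this change.
loader = mx.gluon.data.DataLoader(dataset, batch_size=8, num_workers=2, timeout=120)

for data, label in loader:
    pass  # consume batches as usual
```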
50 changes: 40 additions & 10 deletions python/mxnet/gluon/data/dataloader.py
@@ -24,6 +24,7 @@
import pickle
import io
import sys
+import signal
import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
@@ -426,7 +427,8 @@ def _thread_worker_fn(samples, batchify_fn, dataset):
class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
-pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None, data_loader=None):
+pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None,
+data_loader=None, timeout=120):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
@@ -439,6 +441,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
self._pin_device_id = pin_device_id
self._dataset = dataset
self._data_loader = data_loader
+self._timeout = timeout
# pre-fetch
for _ in range(prefetch):
self._push_next()
@@ -465,12 +468,29 @@ def __next__(self):
assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
ret = self._data_buffer.pop(self._rcvd_idx)
-batch = pickle.loads(ret.get()) if self._dataset is None else ret.get()
-if self._pin_memory:
-    batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
-batch = batch[0] if len(batch) == 1 else batch
-self._rcvd_idx += 1
-return batch
+try:
+    if self._dataset is None:
+        batch = pickle.loads(ret.get(self._timeout))
+    else:
+        batch = ret.get(self._timeout)
+    if self._pin_memory:
+        batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
+    batch = batch[0] if len(batch) == 1 else batch
+    self._rcvd_idx += 1
+    return batch
+except multiprocessing.context.TimeoutError:
+    msg = '''Worker timed out after {} seconds. This might be caused by \n
+    - Slow transform. Please increase timeout to allow slower data loading in each worker.
+    '''.format(self._timeout)
+    if not isinstance(self._worker_pool, multiprocessing.pool.ThreadPool):
+        msg += '''- Insufficient shared_memory if `timeout` is large enough.
+    Please consider reducing `num_workers` or increasing the system shared_memory.
+    '''
+    print(msg)
+    raise
+except Exception:
+    self._worker_pool.terminate()
+    raise

def next(self):
return self.__next__()
@@ -537,16 +557,22 @@ def default_batchify_fn(data):
If ``True``, use threading pool instead of multiprocessing pool. Using threadpool
can avoid shared memory usage. If `DataLoader` is more IO bounded or GIL is not a killing
problem, threadpool version may achieve better performance than multiprocessing.
+timeout : int, default is 120
+    The timeout in seconds for each worker to fetch a batch of data. Only modify this number
+    if you are experiencing a timeout and you know it is caused by slow data loading.
+    Sometimes a full `shared_memory` will cause all workers to hang and trigger a timeout. In these
+    cases, please reduce `num_workers` or increase the system `shared_memory` size instead.
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, pin_device_id=0,
-prefetch=None, thread_pool=False):
+prefetch=None, thread_pool=False, timeout=120):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._thread_pool = thread_pool
+self._timeout = timeout
+assert timeout > 0, "timeout must be positive, given {}".format(timeout)

if batch_sampler is None:
if batch_size is None:
@@ -577,9 +603,13 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
initializer=_thread_worker_initializer,
initargs=(is_np_shape(), is_np_array()))
else:
+# set ignore keyboard interrupt signal before forking processes
+original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
self._worker_pool = multiprocessing.Pool(
    self._num_workers, initializer=_worker_initializer,
    initargs=[self._dataset, is_np_shape(), is_np_array()])
+# resume keyboard interrupt signal in main process
+signal.signal(signal.SIGINT, original_sigint_handler)
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = default_mp_batchify_fn
@@ -604,7 +634,7 @@ def same_process_iter():
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None,
-data_loader=self)
+data_loader=self, timeout=self._timeout)

def __len__(self):
return len(self._batch_sampler)
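The `__next__` change above follows a standard pattern: wait on the pool's pending result with a bounded `get(timeout)` and print a readable message when the wait expires. Below is a standalone sketch of that pattern using plain `multiprocessing` (not MXNet code); the function and values are made up for illustration:

```python
import multiprocessing
import time

def slow_transform(x):
    time.sleep(x)  # stands in for a slow per-sample transform
    return x * 2

if __name__ == '__main__':
    with multiprocessing.Pool(2) as pool:
        ret = pool.apply_async(slow_transform, (5,))  # analogous to a pending batch in _data_buffer
        try:
            print(ret.get(timeout=1))  # bounded wait, like ret.get(self._timeout) above
        except multiprocessing.TimeoutError:
            print('Worker timed out; increase the timeout or reduce the number of workers.')
```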
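The signal handling added to `DataLoader.__init__` is likewise a general pattern: temporarily ignore SIGINT so newly started workers begin with SIG_IGN, then restore the handler so only the main process reacts to Ctrl+C and can shut the pool down. A standalone sketch of the same idea, again plain `multiprocessing` rather than MXNet code:

```python
import multiprocessing
import signal

def work(x):
    return x + 1

if __name__ == '__main__':
    # Ignore SIGINT while the pool is created so worker processes start with SIG_IGN
    # (inherited on fork-based platforms) instead of each receiving Ctrl+C.
    original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    pool = multiprocessing.Pool(2)
    # Restore the handler so the main process still responds to Ctrl+C.
    signal.signal(signal.SIGINT, original_sigint_handler)
    try:
        print(pool.map(work, range(4)))
    finally:
        pool.terminate()
        pool.join()
```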
5 changes: 4 additions & 1 deletion tests/python/unittest/test_gluon_data.py
@@ -249,7 +249,10 @@ def test_multi_worker_forked_data_loader():
@with_seed()
def test_multi_worker_dataloader_release_pool():
# will trigger too many open file if pool is not released properly
-for _ in range(100):
+if os.name == 'nt':
+    print('Skip for windows since spawn on windows is too expensive.')
+    return
+for _ in range(10):
A = np.random.rand(999, 2000)
D = mx.gluon.data.DataLoader(A, batch_size=8, num_workers=8)
the_iter = iter(D)
