This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

improve dataloader signals and messages
zhreshold committed Sep 6, 2019
1 parent d60be31 commit ff4c092
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions python/mxnet/gluon/data/dataloader.py
@@ -24,6 +24,7 @@
 import pickle
 import io
 import sys
+import signal
 import multiprocessing
 import multiprocessing.queues
 from multiprocessing.reduction import ForkingPickler
@@ -426,7 +427,8 @@ def _thread_worker_fn(samples, batchify_fn, dataset):
 class _MultiWorkerIter(object):
     """Internal multi-worker iterator for DataLoader."""
     def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
-                 pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None, data_loader=None):
+                 pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None,
+                 data_loader=None, timeout=120):
         self._worker_pool = worker_pool
         self._batchify_fn = batchify_fn
         self._batch_sampler = batch_sampler
@@ -439,6 +441,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
         self._pin_device_id = pin_device_id
         self._dataset = dataset
         self._data_loader = data_loader
+        self._timeout = timeout
         # pre-fetch
         for _ in range(prefetch):
             self._push_next()
@@ -465,12 +468,29 @@ def __next__(self):
         assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
         assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
         ret = self._data_buffer.pop(self._rcvd_idx)
-        batch = pickle.loads(ret.get()) if self._dataset is None else ret.get()
-        if self._pin_memory:
-            batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
-        batch = batch[0] if len(batch) == 1 else batch
-        self._rcvd_idx += 1
-        return batch
+        try:
+            if self._dataset is None:
+                batch = pickle.loads(ret.get(self._timeout))
+            else:
+                batch = ret.get(self._timeout)
+            if self._pin_memory:
+                batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
+            batch = batch[0] if len(batch) == 1 else batch
+            self._rcvd_idx += 1
+            return batch
+        except multiprocessing.context.TimeoutError:
+            msg = '''Worker timed out after {} seconds. This might be caused by:
+            - A slow transform. Please increase `timeout` to allow slower data loading in each worker.
+            '''.format(self._timeout)
+            if not isinstance(self._worker_pool, multiprocessing.pool.ThreadPool):
+                msg += '''- Insufficient `shared_memory` if `timeout` is already large enough.
+            Please consider reducing `num_workers` or increasing the system `shared_memory` size instead.
+            '''
+            print(msg)
+            raise
+        except Exception:
+            self._worker_pool.terminate()
+            raise
 
     def next(self):
         return self.__next__()
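
For reference, here is a minimal standalone sketch of the AsyncResult timeout pattern the new `__next__` relies on. This is illustrative only, not MXNet code; the slow task, pool size, and timeout value are hypothetical stand-ins:

    import multiprocessing
    import time

    def _slow_task(x):
        time.sleep(5)  # stand-in for a slow per-sample transform
        return x * 2

    if __name__ == '__main__':
        pool = multiprocessing.Pool(2)
        ret = pool.apply_async(_slow_task, (21,))
        try:
            # AsyncResult.get raises multiprocessing.context.TimeoutError if the
            # result is not ready within `timeout` seconds, as in __next__ above.
            print(ret.get(timeout=1))
        except multiprocessing.context.TimeoutError:
            print('worker timed out; a larger timeout may be needed')
        finally:
            pool.terminate()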
@@ -537,16 +557,21 @@ def default_batchify_fn(data):
         If ``True``, use threading pool instead of multiprocessing pool. Using threadpool
         can avoid shared memory usage. If `DataLoader` is more IO bounded or GIL is not a killing
         problem, threadpool version may achieve better performance than multiprocessing.
+    timeout : int, default is 120
+        The timeout in seconds for each worker to fetch a batch of data. Only modify this
+        number if you are experiencing a timeout and you know it is caused by slow data
+        loading. Sometimes a full `shared_memory` will cause all workers to hang and trigger
+        a timeout; in these cases please reduce `num_workers` or increase the system
+        `shared_memory` size instead.
     """
     def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                  last_batch=None, batch_sampler=None, batchify_fn=None,
                  num_workers=0, pin_memory=False, pin_device_id=0,
-                 prefetch=None, thread_pool=False):
+                 prefetch=None, thread_pool=False, timeout=120):
         self._dataset = dataset
         self._pin_memory = pin_memory
         self._pin_device_id = pin_device_id
         self._thread_pool = thread_pool
+        self._timeout = timeout
 
         if batch_sampler is None:
             if batch_size is None:
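
A hedged usage sketch of the new parameter (the toy dataset and the 300-second value are illustrative, not from this commit):

    from mxnet import nd
    from mxnet.gluon.data import ArrayDataset, DataLoader

    # a toy dataset; a real one might apply a slow transform per sample
    dataset = ArrayDataset(nd.random.uniform(shape=(100, 3)))
    # give each worker up to 300 seconds per batch instead of the default 120
    loader = DataLoader(dataset, batch_size=10, num_workers=2, timeout=300)
    for batch in loader:
        pass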
@@ -577,9 +602,13 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 initializer=_thread_worker_initializer,
                 initargs=(is_np_shape(), is_np_array()))
         else:
+            # ignore the keyboard interrupt (SIGINT) signal before forking worker processes
+            original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
             self._worker_pool = multiprocessing.Pool(
                 self._num_workers, initializer=_worker_initializer,
                 initargs=[self._dataset, is_np_shape(), is_np_array()])
+            # restore the keyboard interrupt handler in the main process
+            signal.signal(signal.SIGINT, original_sigint_handler)
         if batchify_fn is None:
             if num_workers > 0:
                 self._batchify_fn = default_mp_batchify_fn
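
The SIGINT handling above, shown in isolation as a general pattern (the worker initializer here is a hypothetical stand-in for MXNet's `_worker_initializer`):

    import multiprocessing
    import signal

    def _init_worker():
        pass  # hypothetical stand-in for the real worker initializer

    if __name__ == '__main__':
        # workers inherit the handler in effect when they are forked, so ignoring
        # SIGINT here prevents Ctrl-C from being delivered to every worker
        original_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
        pool = multiprocessing.Pool(4, initializer=_init_worker)
        # restore the handler so the main process still responds to Ctrl-C
        signal.signal(signal.SIGINT, original_handler)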
@@ -604,7 +633,7 @@ def same_process_iter():
                 worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
                 prefetch=self._prefetch,
                 dataset=self._dataset if self._thread_pool else None,
-                data_loader=self)
+                data_loader=self, timeout=self._timeout)
 
     def __len__(self):
         return len(self._batch_sampler)
