diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 9f0939ec7f37..10de86099d04 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -24,6 +24,7 @@ import pickle import io import sys +import signal import multiprocessing import multiprocessing.queues from multiprocessing.reduction import ForkingPickler @@ -426,7 +427,8 @@ def _thread_worker_fn(samples, batchify_fn, dataset): class _MultiWorkerIter(object): """Internal multi-worker iterator for DataLoader.""" def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False, - pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None, data_loader=None): + pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None, + data_loader=None, timeout=120): self._worker_pool = worker_pool self._batchify_fn = batchify_fn self._batch_sampler = batch_sampler @@ -439,6 +441,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False, self._pin_device_id = pin_device_id self._dataset = dataset self._data_loader = data_loader + self._timeout = timeout # pre-fetch for _ in range(prefetch): self._push_next() @@ -465,12 +468,29 @@ def __next__(self): assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx" assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing" ret = self._data_buffer.pop(self._rcvd_idx) - batch = pickle.loads(ret.get()) if self._dataset is None else ret.get() - if self._pin_memory: - batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id)) - batch = batch[0] if len(batch) == 1 else batch - self._rcvd_idx += 1 - return batch + try: + if self._dataset is None: + batch = pickle.loads(ret.get(self._timeout)) + else: + batch = ret.get(self._timeout) + if self._pin_memory: + batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id)) + batch = batch[0] if len(batch) == 1 else batch + self._rcvd_idx += 1 + return batch + except multiprocessing.context.TimeoutError: + msg = '''Worker timed out after {} seconds. This might be caused by \n + - Slow transform. Please increase timeout to allow slower data loading in each worker. + '''.format(self._timeout) + if not isinstance(self._worker_pool, multiprocessing.pool.ThreadPool): + msg += '''- Insufficient shared_memory if `timeout` is large enough. + Please consider reduce `num_workers` or increase shared_memory in system. + ''' + print(msg) + raise + except Exception: + self._worker_pool.terminate() + raise def next(self): return self.__next__() @@ -537,16 +557,21 @@ def default_batchify_fn(data): If ``True``, use threading pool instead of multiprocessing pool. Using threadpool can avoid shared memory usage. If `DataLoader` is more IO bounded or GIL is not a killing problem, threadpool version may achieve better performance than multiprocessing. - + timeout : int, default is 120 + The timeout in seconds for each worker to fetch a batch data. Only modify this number + unless you are experiencing timeout and you know it's due to slow data loading. + Sometimes full `shared_memory` will cause all workers to hang and causes timeout. In these + cases please reduce `num_workers` or increase system `shared_memory` size instead. """ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, last_batch=None, batch_sampler=None, batchify_fn=None, num_workers=0, pin_memory=False, pin_device_id=0, - prefetch=None, thread_pool=False): + prefetch=None, thread_pool=False, timeout=120): self._dataset = dataset self._pin_memory = pin_memory self._pin_device_id = pin_device_id self._thread_pool = thread_pool + self._timeout = timeout if batch_sampler is None: if batch_size is None: @@ -577,9 +602,13 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, initializer=_thread_worker_initializer, initargs=(is_np_shape(), is_np_array())) else: + # set ignore keyboard interupt signal before forking processes + original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) self._worker_pool = multiprocessing.Pool( self._num_workers, initializer=_worker_initializer, initargs=[self._dataset, is_np_shape(), is_np_array()]) + # resume keyboard interupt signal in main process + signal.signal(signal.SIGINT, original_sigint_handler) if batchify_fn is None: if num_workers > 0: self._batchify_fn = default_mp_batchify_fn @@ -604,7 +633,7 @@ def same_process_iter(): worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn, prefetch=self._prefetch, dataset=self._dataset if self._thread_pool else None, - data_loader=self) + data_loader=self, timeout=self._timeout) def __len__(self): return len(self._batch_sampler)