This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

improve dataloader signals and messages #16114

Merged 3 commits on Sep 19, 2019
50 changes: 40 additions & 10 deletions python/mxnet/gluon/data/dataloader.py
@@ -24,6 +24,7 @@
import pickle
import io
import sys
import signal
import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
@@ -426,7 +427,8 @@ def _thread_worker_fn(samples, batchify_fn, dataset):
class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None, data_loader=None):
pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None,
data_loader=None, timeout=120):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
@@ -439,6 +441,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
self._pin_device_id = pin_device_id
self._dataset = dataset
self._data_loader = data_loader
self._timeout = timeout
# pre-fetch
for _ in range(prefetch):
self._push_next()
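For context on the prefetch loop above: `_push_next` submits batch jobs to the pool asynchronously and keeps the pending `AsyncResult`s in a buffer keyed by batch index, so `prefetch` batches are already in flight before the first `__next__`. A minimal sketch of that pattern in plain `multiprocessing` (the names here are illustrative, not MXNet's):

```python
import multiprocessing

def load_batch(idx):
    return [idx] * 4  # stand-in for fetching and batchifying samples

if __name__ == '__main__':
    with multiprocessing.Pool(2) as pool:
        buffer, sent_idx = {}, 0
        # keep a few batches in flight before the consumer asks for any
        for _ in range(3):
            buffer[sent_idx] = pool.apply_async(load_batch, (sent_idx,))
            sent_idx += 1
        for rcvd_idx in range(3):
            print(buffer.pop(rcvd_idx).get())
```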
@@ -465,12 +468,29 @@ def __next__(self):
assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
ret = self._data_buffer.pop(self._rcvd_idx)
batch = pickle.loads(ret.get()) if self._dataset is None else ret.get()
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
batch = batch[0] if len(batch) == 1 else batch
self._rcvd_idx += 1
return batch
try:
if self._dataset is None:
batch = pickle.loads(ret.get(self._timeout))
else:
batch = ret.get(self._timeout)
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
batch = batch[0] if len(batch) == 1 else batch
self._rcvd_idx += 1
return batch
except multiprocessing.context.TimeoutError:
            msg = '''Worker timed out after {} seconds. This might be caused by:
            - A slow transform. Please increase `timeout` to allow slower data loading in each worker.
            '''.format(self._timeout)
            if not isinstance(self._worker_pool, multiprocessing.pool.ThreadPool):
                msg += '''- Insufficient shared memory, if `timeout` is already large enough.
            Please consider reducing `num_workers` or increasing the system's shared memory size.
            '''
print(msg)
raise
except Exception:
self._worker_pool.terminate()
raise

def next(self):
return self.__next__()
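The core of the change above is passing a timeout to `AsyncResult.get`, which raises `multiprocessing.context.TimeoutError` (also exposed as `multiprocessing.TimeoutError`) if the worker has not delivered a batch in time; note that the underlying task keeps running in the worker. A minimal standalone sketch of that behavior:

```python
import time
import multiprocessing

def slow_task(seconds):
    time.sleep(seconds)
    return seconds

if __name__ == '__main__':
    with multiprocessing.Pool(1) as pool:
        ret = pool.apply_async(slow_task, (5,))
        try:
            print(ret.get(1))  # wait at most 1 second, like ret.get(self._timeout)
        except multiprocessing.TimeoutError:
            print('worker timed out; the underlying task is still running')
```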
@@ -537,16 +557,22 @@ def default_batchify_fn(data):
If ``True``, use threading pool instead of multiprocessing pool. Using threadpool
can avoid shared memory usage. If `DataLoader` is more IO bounded or GIL is not a killing
problem, threadpool version may achieve better performance than multiprocessing.

    timeout : int, default is 120
        The timeout in seconds for each worker to fetch a batch of data. Only modify
        this number if you are experiencing a timeout and you know it is caused by
        slow data loading. Note that a full `shared_memory` can cause all workers to
        hang, which also manifests as a timeout; in that case, reduce `num_workers`
        or increase the system `shared_memory` size instead.
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, pin_device_id=0,
prefetch=None, thread_pool=False):
prefetch=None, thread_pool=False, timeout=120):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._thread_pool = thread_pool
self._timeout = timeout
assert timeout > 0, "timeout must be positive, given {}".format(timeout)

if batch_sampler is None:
if batch_size is None:
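With the new keyword argument, callers who hit the timeout because of genuinely slow transforms can raise it at construction time. A minimal usage sketch, assuming a build of MXNet that includes this change:

```python
import numpy as np
import mxnet as mx

dataset = mx.gluon.data.ArrayDataset(np.random.rand(64, 10))
# the default is timeout=120; raise it only when slow transforms are expected
loader = mx.gluon.data.DataLoader(dataset, batch_size=8, num_workers=2,
                                  timeout=300)
for batch in loader:
    pass  # consume batches as usual
```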
@@ -577,9 +603,13 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
initializer=_thread_worker_initializer,
initargs=(is_np_shape(), is_np_array()))
else:
            # ignore the keyboard interrupt signal before forking worker processes
original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
self._worker_pool = multiprocessing.Pool(
self._num_workers, initializer=_worker_initializer,
initargs=[self._dataset, is_np_shape(), is_np_array()])
            # restore the keyboard interrupt handler in the main process
signal.signal(signal.SIGINT, original_sigint_handler)
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = default_mp_batchify_fn
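The two new lines above make forked workers inherit an ignored SIGINT, so Ctrl-C interrupts only the main process instead of killing workers mid-batch; the original handler is restored in the parent immediately after the pool is created. A minimal standalone sketch of the same pattern (assumes the fork start method, as on Linux):

```python
import signal
import multiprocessing

def work(x):
    return x * x

if __name__ == '__main__':
    # children forked while SIGINT is ignored inherit that disposition
    original_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    pool = multiprocessing.Pool(2)
    # the parent goes back to handling Ctrl-C normally
    signal.signal(signal.SIGINT, original_handler)
    try:
        print(pool.map(work, range(4)))
    finally:
        pool.terminate()
        pool.join()
```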
@@ -604,7 +634,7 @@ def same_process_iter():
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None,
data_loader=self)
data_loader=self, timeout=self._timeout)

def __len__(self):
return len(self._batch_sampler)
5 changes: 4 additions & 1 deletion tests/python/unittest/test_gluon_data.py
@@ -249,7 +249,10 @@ def test_multi_worker_forked_data_loader():
@with_seed()
def test_multi_worker_dataloader_release_pool():
    # will trigger 'too many open files' if the pool is not released properly
for _ in range(100):
if os.name == 'nt':
        print('Skip on Windows since spawning processes there is too expensive.')
return
for _ in range(10):
A = np.random.rand(999, 2000)
D = mx.gluon.data.DataLoader(A, batch_size=8, num_workers=8)
the_iter = iter(D)
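For context, this test guards against leaked pool file descriptors: each `multiprocessing.Pool` opens several pipes, so repeatedly creating loaders whose pools are never released eventually exhausts the per-process descriptor limit ('too many open files'). A minimal sketch of the explicit-release pattern outside MXNet:

```python
import multiprocessing

if __name__ == '__main__':
    # each Pool holds open pipe descriptors; terminate()/join() releases
    # them promptly instead of waiting for garbage collection
    for _ in range(10):
        pool = multiprocessing.Pool(8)
        pool.terminate()
        pool.join()
```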