Rewrite dataloader with process pool, improves responsiveness and reliability (#13447)

* fix recordio.py

* rewrite dataloader with pool

* fix batch as tuple

* fix prefetching

* fix pylint

* picklable function

* use pickle

* add missing commit
zhreshold authored and eric-haibin-lin committed Nov 30, 2018
1 parent b5ea194 commit 883d771
Showing 2 changed files with 209 additions and 31 deletions.
223 changes: 192 additions & 31 deletions python/mxnet/gluon/data/dataloader.py
@@ -36,7 +36,6 @@

from . import sampler as _sampler
from ... import nd, context
from ...recordio import MXRecordIO

if sys.platform == 'darwin' or sys.platform == 'win32':
def rebuild_ndarray(*args):
@@ -159,37 +158,17 @@ def _as_in_context(data, ctx):
return [_as_in_context(d, ctx) for d in data]
return data

def _recursive_fork_recordio(obj, depth, max_depth=1000):
"""Recursively find instance of MXRecordIO and reset file handler.
This is required for MXRecordIO which holds a C pointer to a opened file after fork.
"""
if depth >= max_depth:
return
if isinstance(obj, MXRecordIO):
obj.close()
obj.open() # re-obtain file handler in new process
elif (hasattr(obj, '__dict__')):
for _, v in obj.__dict__.items():
_recursive_fork_recordio(v, depth + 1, max_depth)

def worker_loop(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
# re-fork a new recordio handler in new process if applicable
# for a dataset with transform function, the depth of MXRecordIO is 1
# for a lazy transformer, the depth is 2
# for a user defined transformer, the depth is unknown, try a reasonable depth
limit = sys.getrecursionlimit()
max_recursion_depth = min(limit - 5, max(10, limit // 2))
_recursive_fork_recordio(dataset, 0, max_recursion_depth)

def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
while True:
idx, samples = key_queue.get()
if idx is None:
break
batch = batchify_fn([dataset[i] for i in samples])
data_queue.put((idx, batch))

def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
"""Fetcher loop for fetching data from queue and put in reorder dict."""
while True:
idx, batch = data_queue.get()
@@ -206,10 +185,10 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non
data_buffer[idx] = batch


class _MultiWorkerIter(object):
"""Interal multi-worker iterator for DataLoader."""
class _MultiWorkerIterV1(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=worker_loop):
worker_fn=worker_loop_v1):
assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
self._num_workers = num_workers
self._dataset = dataset
@@ -237,7 +216,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=
self._workers = workers

self._fetcher = threading.Thread(
target=fetcher_loop,
target=fetcher_loop_v1,
args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
self._fetcher.daemon = True
self._fetcher.start()
@@ -299,7 +278,7 @@ def shutdown(self):
self._shutdown = True


class DataLoader(object):
class DataLoaderV1(object):
"""Loads data from a dataset and returns mini-batches of data.
Parameters
@@ -390,8 +369,190 @@ def same_process_iter():
return same_process_iter()

# multi-worker
return _MultiWorkerIter(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler, self._pin_memory)
return _MultiWorkerIterV1(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler, self._pin_memory)

def __len__(self):
return len(self._batch_sampler)

_worker_dataset = None
def _worker_initializer(dataset):
"""Initialier for processing pool."""
# the global dataset is per-process and only available in worker processes
# this is only necessary to handle MXIndexedRecordIO, because otherwise the dataset
# could simply be passed as an argument
global _worker_dataset
_worker_dataset = dataset

def _worker_fn(samples, batchify_fn):
"""Function for processing data in worker process."""
# each worker process is required to fork its own MXIndexedRecordIO handle
# keeping the dataset as a global variable saves a lot of overhead and is safe in the new process
global _worker_dataset
batch = batchify_fn([_worker_dataset[i] for i in samples])
buf = io.BytesIO()
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
return buf.getvalue()

class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=_worker_fn, prefetch=0):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
self._data_buffer = {}
self._rcvd_idx = 0
self._sent_idx = 0
self._iter = iter(self._batch_sampler)
self._worker_fn = worker_fn
self._pin_memory = pin_memory
# pre-fetch
for _ in range(prefetch):
self._push_next()

def __len__(self):
return len(self._batch_sampler)

def _push_next(self):
"""Assign next batch workload to workers."""
r = next(self._iter, None)
if r is None:
return
async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn))
self._data_buffer[self._sent_idx] = async_ret
self._sent_idx += 1

def __next__(self):
self._push_next()
if self._rcvd_idx == self._sent_idx:
assert not self._data_buffer, "Data buffer should be empty at this moment"
raise StopIteration

assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
ret = self._data_buffer.pop(self._rcvd_idx)
batch = pickle.loads(ret.get())
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
batch = batch[0] if len(batch) == 1 else batch
self._rcvd_idx += 1
return batch

def next(self):
return self.__next__()

def __iter__(self):
return self


class DataLoader(object):
"""Loads data from a dataset and returns mini-batches of data.
Parameters
----------
dataset : Dataset
Source dataset. Note that numpy and mxnet arrays can be directly used
as a Dataset.
batch_size : int
Size of mini-batch.
shuffle : bool
Whether to shuffle the samples.
sampler : Sampler
The sampler to use. Either specify sampler or shuffle, not both.
last_batch : {'keep', 'discard', 'rollover'}
How to handle the last batch if batch_size does not evenly divide
`len(dataset)`.
keep - A batch with fewer samples than previous batches is returned.
discard - The last batch is discarded if it is incomplete.
rollover - The remaining samples are rolled over to the next epoch.
batch_sampler : Sampler
A sampler that returns mini-batches. Do not specify batch_size,
shuffle, sampler, and last_batch if batch_sampler is specified.
batchify_fn : callable
Callback function to allow users to specify how to merge samples
into a batch. Defaults to `default_batchify_fn`::
def default_batchify_fn(data):
if isinstance(data[0], nd.NDArray):
return nd.stack(*data)
elif isinstance(data[0], tuple):
data = zip(*data)
return [default_batchify_fn(i) for i in data]
else:
data = np.asarray(data)
return nd.array(data, dtype=data.dtype)
num_workers : int, default 0
The number of multiprocessing workers to use for data preprocessing.
pin_memory : boolean, default False
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
prefetch : int, default is `num_workers * 2`
The number of batches to prefetch; only takes effect if `num_workers` > 0.
If `prefetch` > 0, worker processes are allowed to prefetch that many batches
before data is requested from the iterator.
Note that a larger prefetch count gives smoother start-up performance,
but consumes more shared memory. Too small a value may defeat the purpose of using
multiple worker processes; try reducing `num_workers` in that case.
When not specified, it defaults to `num_workers * 2`.
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, prefetch=None):
self._dataset = dataset
self._pin_memory = pin_memory

if batch_sampler is None:
if batch_size is None:
raise ValueError("batch_size must be specified unless " \
"batch_sampler is specified")
if sampler is None:
if shuffle:
sampler = _sampler.RandomSampler(len(dataset))
else:
sampler = _sampler.SequentialSampler(len(dataset))
elif shuffle:
raise ValueError("shuffle must not be specified if sampler is specified")

batch_sampler = _sampler.BatchSampler(
sampler, batch_size, last_batch if last_batch else 'keep')
elif batch_size is not None or shuffle or sampler is not None or \
last_batch is not None:
raise ValueError("batch_size, shuffle, sampler and last_batch must " \
"not be specified if batch_sampler is specified.")

self._batch_sampler = batch_sampler
self._num_workers = num_workers if num_workers >= 0 else 0
self._worker_pool = None
self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
if self._num_workers > 0:
self._worker_pool = multiprocessing.Pool(
self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = default_mp_batchify_fn
else:
self._batchify_fn = default_batchify_fn
else:
self._batchify_fn = batchify_fn

def __iter__(self):
if self._num_workers == 0:
def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned())
yield ret
return same_process_iter()

# multi-worker
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory, worker_fn=_worker_fn,
prefetch=self._prefetch)

def __len__(self):
return len(self._batch_sampler)
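For reference, a minimal usage sketch of the pool-based loader introduced above. The toy ArrayDataset, the random data, and the chosen parameter values are illustrative assumptions; the constructor arguments themselves (batch_size, shuffle, num_workers, pin_memory, prefetch) are the ones added in this diff. On platforms that spawn rather than fork worker processes, the snippet should be wrapped in an `if __name__ == '__main__':` guard.

import mxnet as mx
from mxnet.gluon.data import ArrayDataset, DataLoader

# toy dataset: 100 random 3x32x32 "images" with integer labels
features = mx.nd.random.uniform(shape=(100, 3, 32, 32))
labels = mx.nd.arange(100)
dataset = ArrayDataset(features, labels)

# num_workers > 0 creates a multiprocessing.Pool initialized with the dataset;
# prefetch falls back to 2 * num_workers when left unset
loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    num_workers=2, pin_memory=False, prefetch=4)

for data, label in loader:
    # each batch is produced by a pool worker, shipped back as ForkingPickler
    # bytes and deserialized in the main process with pickle.loads
    print(data.shape, label.shape)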
17 changes: 17 additions & 0 deletions python/mxnet/recordio.py
@@ -18,6 +18,7 @@
"""Read and write for the RecordIO data format."""
from __future__ import absolute_import
from collections import namedtuple
from multiprocessing import current_process

import ctypes
import struct
@@ -65,6 +66,7 @@ def __init__(self, uri, flag):
self.uri = c_str(uri)
self.handle = RecordIOHandle()
self.flag = flag
self.pid = None
self.is_open = False
self.open()

@@ -78,6 +80,7 @@ def open(self):
self.writable = False
else:
raise ValueError("Invalid flag %s"%self.flag)
self.pid = current_process().pid
self.is_open = True

def __del__(self):
@@ -109,6 +112,14 @@ def __setstate__(self, d):
if is_open:
self.open()

def _check_pid(self, allow_reset=False):
"""Check process id to ensure integrity, reset if in new process."""
if not self.pid == current_process().pid:
if allow_reset:
self.reset()
else:
raise RuntimeError("Forbidden operation in multiple processes")

def close(self):
"""Closes the record file."""
if not self.is_open:
@@ -118,6 +129,7 @@ def close(self):
else:
check_call(_LIB.MXRecordIOReaderFree(self.handle))
self.is_open = False
self.pid = None

def reset(self):
"""Resets the pointer to first item.
@@ -156,6 +168,7 @@ def write(self, buf):
Buffer to write.
"""
assert self.writable
self._check_pid(allow_reset=False)
check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
ctypes.c_char_p(buf),
ctypes.c_size_t(len(buf))))
@@ -182,6 +195,9 @@ def read(self):
Buffer read.
"""
assert not self.writable
# implicitly reading from multiple processes is forbidden;
# there is no elegant way to handle it unless a lock is introduced
self._check_pid(allow_reset=False)
buf = ctypes.c_char_p()
size = ctypes.c_size_t()
check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
@@ -255,6 +271,7 @@ def seek(self, idx):
This function is internally called by `read_idx(idx)` to find the current
reader pointer position. It doesn't return anything."""
assert not self.writable
self._check_pid(allow_reset=True)
pos = ctypes.c_size_t(self.idx[idx])
check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))

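The recordio.py changes above attach a process id to each open record file so that a handle inherited across fork() is either re-opened in the child (seek) or rejected (read/write). Below is a standalone sketch of that pid-guard pattern, using a hypothetical ToyRecordReader rather than the actual MXRecordIO code:

from multiprocessing import current_process

class ToyRecordReader(object):
    """Toy reader mimicking the pid guard added to MXRecordIO."""
    def __init__(self, path):
        self.path = path
        self.handle = None
        self.pid = None
        self.open()

    def open(self):
        if self.handle is not None:
            self.handle.close()
        self.handle = open(self.path, 'rb')
        # remember which process owns this handle
        self.pid = current_process().pid

    def _check_pid(self, allow_reset=False):
        # a handle inherited across fork() must not be reused blindly:
        # either re-open it in the child (allow_reset=True, as seek() does)
        # or refuse the operation (as read() and write() do)
        if self.pid != current_process().pid:
            if allow_reset:
                self.open()
            else:
                raise RuntimeError("Forbidden operation in multiple processes")

    def read(self, size=16):
        self._check_pid(allow_reset=False)
        return self.handle.read(size)

    def seek(self, pos):
        self._check_pid(allow_reset=True)
        self.handle.seek(pos)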
