diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 86cb835f5128..ad0f534d16dd 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -36,7 +36,6 @@ from . import sampler as _sampler from ... import nd, context -from ...recordio import MXRecordIO if sys.platform == 'darwin' or sys.platform == 'win32': def rebuild_ndarray(*args): @@ -159,29 +158,9 @@ def _as_in_context(data, ctx): return [_as_in_context(d, ctx) for d in data] return data -def _recursive_fork_recordio(obj, depth, max_depth=1000): - """Recursively find instance of MXRecordIO and reset file handler. - This is required for MXRecordIO which holds a C pointer to a opened file after fork. - """ - if depth >= max_depth: - return - if isinstance(obj, MXRecordIO): - obj.close() - obj.open() # re-obtain file hanlder in new process - elif (hasattr(obj, '__dict__')): - for _, v in obj.__dict__.items(): - _recursive_fork_recordio(v, depth + 1, max_depth) - -def worker_loop(dataset, key_queue, data_queue, batchify_fn): - """Worker loop for multiprocessing DataLoader.""" - # re-fork a new recordio handler in new process if applicable - # for a dataset with transform function, the depth of MXRecordIO is 1 - # for a lazy transformer, the depth is 2 - # for a user defined transformer, the depth is unknown, try a reasonable depth - limit = sys.getrecursionlimit() - max_recursion_depth = min(limit - 5, max(10, limit // 2)) - _recursive_fork_recordio(dataset, 0, max_recursion_depth) +def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn): + """Worker loop for multiprocessing DataLoader.""" while True: idx, samples = key_queue.get() if idx is None: @@ -189,7 +168,7 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn): batch = batchify_fn([dataset[i] for i in samples]) data_queue.put((idx, batch)) -def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None): +def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None): """Fetcher loop for fetching data from queue and put in reorder dict.""" while True: idx, batch = data_queue.get() @@ -206,10 +185,10 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non data_buffer[idx] = batch -class _MultiWorkerIter(object): - """Interal multi-worker iterator for DataLoader.""" +class _MultiWorkerIterV1(object): + """Internal multi-worker iterator for DataLoader.""" def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, - worker_fn=worker_loop): + worker_fn=worker_loop_v1): assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers) self._num_workers = num_workers self._dataset = dataset @@ -237,7 +216,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory= self._workers = workers self._fetcher = threading.Thread( - target=fetcher_loop, + target=fetcher_loop_v1, args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock)) self._fetcher.daemon = True self._fetcher.start() @@ -299,7 +278,7 @@ def shutdown(self): self._shutdown = True -class DataLoader(object): +class DataLoaderV1(object): """Loads data from a dataset and returns mini-batches of data. Parameters @@ -390,8 +369,190 @@ def same_process_iter(): return same_process_iter() # multi-worker - return _MultiWorkerIter(self._num_workers, self._dataset, - self._batchify_fn, self._batch_sampler, self._pin_memory) + return _MultiWorkerIterV1(self._num_workers, self._dataset, + self._batchify_fn, self._batch_sampler, self._pin_memory) + + def __len__(self): + return len(self._batch_sampler) + +_worker_dataset = None +def _worker_initializer(dataset): + """Initialier for processing pool.""" + # global dataset is per-process based and only available in worker processes + # this is only necessary to handle MXIndexedRecordIO because otherwise dataset + # can be passed as argument + global _worker_dataset + _worker_dataset = dataset + +def _worker_fn(samples, batchify_fn): + """Function for processing data in worker process.""" + # it is required that each worker process has to fork a new MXIndexedRecordIO handle + # preserving dataset as global variable can save tons of overhead and is safe in new process + global _worker_dataset + batch = batchify_fn([_worker_dataset[i] for i in samples]) + buf = io.BytesIO() + ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch) + return buf.getvalue() + +class _MultiWorkerIter(object): + """Internal multi-worker iterator for DataLoader.""" + def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False, + worker_fn=_worker_fn, prefetch=0): + self._worker_pool = worker_pool + self._batchify_fn = batchify_fn + self._batch_sampler = batch_sampler + self._data_buffer = {} + self._rcvd_idx = 0 + self._sent_idx = 0 + self._iter = iter(self._batch_sampler) + self._worker_fn = worker_fn + self._pin_memory = pin_memory + # pre-fetch + for _ in range(prefetch): + self._push_next() + + def __len__(self): + return len(self._batch_sampler) + + def _push_next(self): + """Assign next batch workload to workers.""" + r = next(self._iter, None) + if r is None: + return + async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn)) + self._data_buffer[self._sent_idx] = async_ret + self._sent_idx += 1 + + def __next__(self): + self._push_next() + if self._rcvd_idx == self._sent_idx: + assert not self._data_buffer, "Data buffer should be empty at this moment" + raise StopIteration + + assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx" + assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing" + ret = self._data_buffer.pop(self._rcvd_idx) + batch = pickle.loads(ret.get()) + if self._pin_memory: + batch = _as_in_context(batch, context.cpu_pinned()) + batch = batch[0] if len(batch) == 1 else batch + self._rcvd_idx += 1 + return batch + + def next(self): + return self.__next__() + + def __iter__(self): + return self + + +class DataLoader(object): + """Loads data from a dataset and returns mini-batches of data. + + Parameters + ---------- + dataset : Dataset + Source dataset. Note that numpy and mxnet arrays can be directly used + as a Dataset. + batch_size : int + Size of mini-batch. + shuffle : bool + Whether to shuffle the samples. + sampler : Sampler + The sampler to use. Either specify sampler or shuffle, not both. + last_batch : {'keep', 'discard', 'rollover'} + How to handle the last batch if batch_size does not evenly divide + `len(dataset)`. + + keep - A batch with less samples than previous batches is returned. + discard - The last batch is discarded if its incomplete. + rollover - The remaining samples are rolled over to the next epoch. + batch_sampler : Sampler + A sampler that returns mini-batches. Do not specify batch_size, + shuffle, sampler, and last_batch if batch_sampler is specified. + batchify_fn : callable + Callback function to allow users to specify how to merge samples + into a batch. Defaults to `default_batchify_fn`:: + + def default_batchify_fn(data): + if isinstance(data[0], nd.NDArray): + return nd.stack(*data) + elif isinstance(data[0], tuple): + data = zip(*data) + return [default_batchify_fn(i) for i in data] + else: + data = np.asarray(data) + return nd.array(data, dtype=data.dtype) + + num_workers : int, default 0 + The number of multiprocessing workers to use for data preprocessing. + pin_memory : boolean, default False + If ``True``, the dataloader will copy NDArrays into pinned memory + before returning them. Copying from CPU pinned memory to GPU is faster + than from normal CPU memory. + prefetch : int, default is `num_workers * 2` + The number of prefetching batches only works if `num_workers` > 0. + If `prefetch` > 0, it allow worker process to prefetch certain batches before + acquiring data from iterators. + Note that using large prefetching batch will provide smoother bootstrapping performance, + but will consume more shared_memory. Using smaller number may forfeit the purpose of using + multiple worker processes, try reduce `num_workers` in this case. + By default it defaults to `num_workers * 2`. + """ + def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, + last_batch=None, batch_sampler=None, batchify_fn=None, + num_workers=0, pin_memory=False, prefetch=None): + self._dataset = dataset + self._pin_memory = pin_memory + + if batch_sampler is None: + if batch_size is None: + raise ValueError("batch_size must be specified unless " \ + "batch_sampler is specified") + if sampler is None: + if shuffle: + sampler = _sampler.RandomSampler(len(dataset)) + else: + sampler = _sampler.SequentialSampler(len(dataset)) + elif shuffle: + raise ValueError("shuffle must not be specified if sampler is specified") + + batch_sampler = _sampler.BatchSampler( + sampler, batch_size, last_batch if last_batch else 'keep') + elif batch_size is not None or shuffle or sampler is not None or \ + last_batch is not None: + raise ValueError("batch_size, shuffle, sampler and last_batch must " \ + "not be specified if batch_sampler is specified.") + + self._batch_sampler = batch_sampler + self._num_workers = num_workers if num_workers >= 0 else 0 + self._worker_pool = None + self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers) + if self._num_workers > 0: + self._worker_pool = multiprocessing.Pool( + self._num_workers, initializer=_worker_initializer, initargs=[self._dataset]) + if batchify_fn is None: + if num_workers > 0: + self._batchify_fn = default_mp_batchify_fn + else: + self._batchify_fn = default_batchify_fn + else: + self._batchify_fn = batchify_fn + + def __iter__(self): + if self._num_workers == 0: + def same_process_iter(): + for batch in self._batch_sampler: + ret = self._batchify_fn([self._dataset[idx] for idx in batch]) + if self._pin_memory: + ret = _as_in_context(ret, context.cpu_pinned()) + yield ret + return same_process_iter() + + # multi-worker + return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler, + pin_memory=self._pin_memory, worker_fn=_worker_fn, + prefetch=self._prefetch) def __len__(self): return len(self._batch_sampler) diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py index 2def141c9340..bdc63235d702 100644 --- a/python/mxnet/recordio.py +++ b/python/mxnet/recordio.py @@ -18,6 +18,7 @@ """Read and write for the RecordIO data format.""" from __future__ import absolute_import from collections import namedtuple +from multiprocessing import current_process import ctypes import struct @@ -65,6 +66,7 @@ def __init__(self, uri, flag): self.uri = c_str(uri) self.handle = RecordIOHandle() self.flag = flag + self.pid = None self.is_open = False self.open() @@ -78,6 +80,7 @@ def open(self): self.writable = False else: raise ValueError("Invalid flag %s"%self.flag) + self.pid = current_process().pid self.is_open = True def __del__(self): @@ -109,6 +112,14 @@ def __setstate__(self, d): if is_open: self.open() + def _check_pid(self, allow_reset=False): + """Check process id to ensure integrity, reset if in new process.""" + if not self.pid == current_process().pid: + if allow_reset: + self.reset() + else: + raise RuntimeError("Forbidden operation in multiple processes") + def close(self): """Closes the record file.""" if not self.is_open: @@ -118,6 +129,7 @@ def close(self): else: check_call(_LIB.MXRecordIOReaderFree(self.handle)) self.is_open = False + self.pid = None def reset(self): """Resets the pointer to first item. @@ -156,6 +168,7 @@ def write(self, buf): Buffer to write. """ assert self.writable + self._check_pid(allow_reset=False) check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle, ctypes.c_char_p(buf), ctypes.c_size_t(len(buf)))) @@ -182,6 +195,9 @@ def read(self): Buffer read. """ assert not self.writable + # trying to implicitly read from multiple processes is forbidden, + # there's no elegant way to handle unless lock is introduced + self._check_pid(allow_reset=False) buf = ctypes.c_char_p() size = ctypes.c_size_t() check_call(_LIB.MXRecordIOReaderReadRecord(self.handle, @@ -255,6 +271,7 @@ def seek(self, idx): This function is internally called by `read_idx(idx)` to find the current reader pointer position. It doesn't return anything.""" assert not self.writable + self._check_pid(allow_reset=True) pos = ctypes.c_size_t(self.idx[idx]) check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))