From 23e40b9c5274a204d3bf485536ed39f075f14a56 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Mon, 10 Dec 2018 16:47:34 -0800 Subject: [PATCH 1/5] init --- docs/faq/env_var.md | 21 ++++++++++++------ python/mxnet/gluon/data/dataloader.py | 31 +++++++++++++++++++-------- src/initialize.cc | 15 ++++++++----- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 8d08e320721a..888b7082b98f 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -37,6 +37,15 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 * MXNET_CPU_NNPACK_NTHREADS - Values: Int ```(default=4)``` - The number of threads used for NNPACK. NNPACK package aims to provide high-performance implementations of some layers for multi-core CPUs. Checkout [NNPACK](http://mxnet.io/faq/nnpack.html) to know more about it. +* MXNET_MP_WORKER_NTHREADS + - Values: Int ```(default=1)``` + - The number of scheduling threads on CPU given to multiprocess workers. Enlarge this number allows more operators to run in parallel in individual workers but please consider reduce the overall `num_workers` to avoid thread contention (Not available on Windows). +* MXNET_MP_OMP_NUM_THREADS + - Values: Int ```(default=1)``` + - The number of OpenMP threads limit given to multiprocess workers. OpenMP is disabled in worker process if `MXNET_MP_OMP_NUM_THREADS` <= 1 (default). Enlarge this number may boost operator execution performance of individual workers but please consider reduce the overall `num_workers` to avoid thread contention (Not available on Windows). +* MXNET_MP_OPENCV_NUM_THREADS + - Values: Int ```(default=0)``` + - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reduce the overall `num_workers` to avoid thread contention (Not available on Windows). ## Memory Options @@ -99,10 +108,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 * MXNET_KVSTORE_REDUCTION_NTHREADS - Values: Int ```(default=4)``` - The number of CPU threads used for summing up big arrays on a single machine - - This will also be used for `dist_sync` kvstore to sum up arrays from different contexts on a single machine. - - This does not affect summing up of arrays from different machines on servers. + - This will also be used for `dist_sync` kvstore to sum up arrays from different contexts on a single machine. + - This does not affect summing up of arrays from different machines on servers. - Summing up of arrays for `dist_sync_device` kvstore is also unaffected as that happens on GPUs. - + * MXNET_KVSTORE_BIGARRAY_BOUND - Values: Int ```(default=1000000)``` - The minimum size of a "big array". @@ -166,7 +175,7 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca * MXNET_CUDNN_AUTOTUNE_DEFAULT - Values: 0, 1, or 2 ```(default=1)``` - - The default value of cudnn auto tuning for convolution layers. + - The default value of cudnn auto tuning for convolution layers. - Value of 0 means there is no auto tuning to pick the convolution algo - Performance tests are run to pick the convolution algo when value is 1 or 2 - Value of 1 chooses the best algo in a limited workspace @@ -190,12 +199,12 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca * MXNET_HOME - Data directory in the filesystem for storage, for example when downloading gluon models. - Default in *nix is .mxnet APPDATA/mxnet in windows. - + * MXNET_MKLDNN_ENABLED - Values: 0, 1 ```(default=1)``` - Flag to enable or disable MKLDNN accelerator. On by default. - Only applies to mxnet that has been compiled with MKLDNN (```pip install mxnet-mkl``` or built from source with ```USE_MKLDNN=1```) - + * MXNET_MKLDNN_CACHE_NUM - Values: Int ```(default=-1)``` - Flag to set num of elements that MKLDNN cache can hold. Default is -1 which means cache size is unbounded. Should only be set if your model has variable input shapes, as cache size may grow unbounded. The number represents the number of items in the cache and is proportional to the number of layers that use MKLDNN and different input shape. diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 586e620470d3..7388123149cf 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -26,6 +26,7 @@ import multiprocessing import multiprocessing.queues from multiprocessing.reduction import ForkingPickler +from multiprocessing.pool import ThreadPool import threading import numpy as np @@ -384,7 +385,7 @@ def _worker_initializer(dataset): global _worker_dataset _worker_dataset = dataset -def _worker_fn(samples, batchify_fn): +def _worker_fn(samples, batchify_fn, dataset=None): """Function for processing data in worker process.""" # it is required that each worker process has to fork a new MXIndexedRecordIO handle # preserving dataset as global variable can save tons of overhead and is safe in new process @@ -394,10 +395,14 @@ def _worker_fn(samples, batchify_fn): ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch) return buf.getvalue() +def _thread_worker_fn(samples, batchify_fn, dataset): + """Threadpool worker function for processing data.""" + return batchify_fn([dataset[i] for i in samples]) + class _MultiWorkerIter(object): """Internal multi-worker iterator for DataLoader.""" def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False, - worker_fn=_worker_fn, prefetch=0): + worker_fn=_worker_fn, prefetch=0, dataset=None): self._worker_pool = worker_pool self._batchify_fn = batchify_fn self._batch_sampler = batch_sampler @@ -407,6 +412,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False, self._iter = iter(self._batch_sampler) self._worker_fn = worker_fn self._pin_memory = pin_memory + self._dataset = dataset # pre-fetch for _ in range(prefetch): self._push_next() @@ -419,7 +425,8 @@ def _push_next(self): r = next(self._iter, None) if r is None: return - async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn)) + async_ret = self._worker_pool.apply_async( + self._worker_fn, (r, self._batchify_fn, self._dataset)) self._data_buffer[self._sent_idx] = async_ret self._sent_idx += 1 @@ -432,7 +439,7 @@ def __next__(self): assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx" assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing" ret = self._data_buffer.pop(self._rcvd_idx) - batch = pickle.loads(ret.get()) + batch = pickle.loads(ret.get()) if self._dataset is None else ret.get() if self._pin_memory: batch = _as_in_context(batch, context.cpu_pinned()) batch = batch[0] if len(batch) == 1 else batch @@ -501,9 +508,10 @@ def default_batchify_fn(data): """ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, last_batch=None, batch_sampler=None, batchify_fn=None, - num_workers=0, pin_memory=False, prefetch=None): + num_workers=0, pin_memory=False, prefetch=None, thread_pool=False): self._dataset = dataset self._pin_memory = pin_memory + self._thread_pool = thread_pool if batch_sampler is None: if batch_size is None: @@ -529,8 +537,11 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, self._worker_pool = None self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers) if self._num_workers > 0: - self._worker_pool = multiprocessing.Pool( - self._num_workers, initializer=_worker_initializer, initargs=[self._dataset]) + if self._thread_pool: + self._worker_pool = ThreadPool(self._num_workers) + else: + self._worker_pool = multiprocessing.Pool( + self._num_workers, initializer=_worker_initializer, initargs=[self._dataset]) if batchify_fn is None: if num_workers > 0: self._batchify_fn = default_mp_batchify_fn @@ -551,8 +562,10 @@ def same_process_iter(): # multi-worker return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler, - pin_memory=self._pin_memory, worker_fn=_worker_fn, - prefetch=self._prefetch) + pin_memory=self._pin_memory, + worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn, + prefetch=self._prefetch, + dataset=self._dataset if self._thread_pool else None) def __len__(self): return len(self._batch_sampler) diff --git a/src/initialize.cc b/src/initialize.cc index ddda3f18a3ae..921169a6ebf9 100644 --- a/src/initialize.cc +++ b/src/initialize.cc @@ -57,13 +57,18 @@ class LibraryInitializer { Engine::Get()->Start(); }, []() { - // Make children single threaded since they are typically workers - dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", 1); - dmlc::SetEnv("OMP_NUM_THREADS", 1); + // Conservative thread management for multiprocess workers + const size_t mp_worker_threads = dmlc::GetEnv("MXNET_MP_WORKER_NTHREADS", 1); + const size_t mp_omp_threads = dmlc::GetEnv("MXNET_MP_OMP_NUM_THREADS", 1); + dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", mp_worker_threads); + dmlc::SetEnv("OMP_NUM_THREADS", mp_omp_threads); #if MXNET_USE_OPENCV && !__APPLE__ - cv::setNumThreads(0); // disable opencv threading + const size_t mp_cv_num_threads = dmlc::GetEnv("MXNET_MP_OPENCV_NUM_THREADS", 0); + cv::setNumThreads(mp_cv_num_threads); // disable opencv threading #endif // MXNET_USE_OPENCV - engine::OpenMP::Get()->set_enabled(false); + if (mp_omp_threads <= 1) { + engine::OpenMP::Get()->set_enabled(false); + } Engine::Get()->Start(); }); #endif From aadc16a3fa77ef7bb7273f59e8f883bf480b9f09 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Mon, 10 Dec 2018 17:27:22 -0800 Subject: [PATCH 2/5] add tests --- python/mxnet/gluon/data/dataloader.py | 8 +++++++- tests/python/unittest/test_gluon_data.py | 7 ++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 7388123149cf..a3fa909f0c6e 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -505,6 +505,11 @@ def default_batchify_fn(data): but will consume more shared_memory. Using smaller number may forfeit the purpose of using multiple worker processes, try reduce `num_workers` in this case. By default it defaults to `num_workers * 2`. + thread_pool : bool, default False + If ``True``, use threading pool instead of multiprocessing pool. Using threadpool + can avoid shared memory usage. If `DataLoader` is more IO bounded or GIL is not a killing + problem, threadpool version may achieve better performance than multiprocessing. + """ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, last_batch=None, batch_sampler=None, batchify_fn=None, @@ -572,6 +577,7 @@ def __len__(self): def __del__(self): if self._worker_pool: - # manually terminate due to a bug that pool is not automatically terminated on linux + # manually terminate due to a bug that pool is not automatically terminated + # https://bugs.python.org/issue34172 assert isinstance(self._worker_pool, multiprocessing.pool.Pool) self._worker_pool.terminate() diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index a3ba222c71d8..6a5322616e20 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -156,9 +156,10 @@ def __getitem__(self, key): @with_seed() def test_multi_worker(): data = Dataset() - loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5) - for i, batch in enumerate(loader): - assert (batch.asnumpy() == i).all() + for thread_pool in [True, False]: + loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5, thread_pool=thread_pool) + for i, batch in enumerate(loader): + assert (batch.asnumpy() == i).all() class _Dummy(Dataset): """Dummy dataset for randomized shape arrays.""" From cec93c166286a14873643bbfe95ef387dc0e1c08 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Mon, 10 Dec 2018 17:35:43 -0800 Subject: [PATCH 3/5] doc --- docs/faq/env_var.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 888b7082b98f..d1a035b95803 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -39,13 +39,13 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - The number of threads used for NNPACK. NNPACK package aims to provide high-performance implementations of some layers for multi-core CPUs. Checkout [NNPACK](http://mxnet.io/faq/nnpack.html) to know more about it. * MXNET_MP_WORKER_NTHREADS - Values: Int ```(default=1)``` - - The number of scheduling threads on CPU given to multiprocess workers. Enlarge this number allows more operators to run in parallel in individual workers but please consider reduce the overall `num_workers` to avoid thread contention (Not available on Windows). + - The number of scheduling threads on CPU given to multiprocess workers. Enlarge this number allows more operators to run in parallel in individual workers but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows). * MXNET_MP_OMP_NUM_THREADS - Values: Int ```(default=1)``` - - The number of OpenMP threads limit given to multiprocess workers. OpenMP is disabled in worker process if `MXNET_MP_OMP_NUM_THREADS` <= 1 (default). Enlarge this number may boost operator execution performance of individual workers but please consider reduce the overall `num_workers` to avoid thread contention (Not available on Windows). + - The number of OpenMP threads limit given to multiprocess workers. OpenMP is disabled in worker process if `MXNET_MP_OMP_NUM_THREADS` <= 1 (default). Enlarge this number may boost operator execution performance of individual workers but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows). * MXNET_MP_OPENCV_NUM_THREADS - Values: Int ```(default=0)``` - - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reduce the overall `num_workers` to avoid thread contention (Not available on Windows). + - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows). ## Memory Options From 229b8fb732fdbdc96f2dec27815b332d6b856dda Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Mon, 10 Dec 2018 17:44:26 -0800 Subject: [PATCH 4/5] lint --- python/mxnet/gluon/data/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index a3fa909f0c6e..9d762745a407 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -387,6 +387,7 @@ def _worker_initializer(dataset): def _worker_fn(samples, batchify_fn, dataset=None): """Function for processing data in worker process.""" + # pylint: disable=unused-argument # it is required that each worker process has to fork a new MXIndexedRecordIO handle # preserving dataset as global variable can save tons of overhead and is safe in new process global _worker_dataset From df94ac83742a9c123efec3e5dca8a87a6c445d43 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Wed, 12 Dec 2018 16:17:10 -0800 Subject: [PATCH 5/5] fix openmp --- docs/faq/env_var.md | 3 --- src/initialize.cc | 7 ++----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index d1a035b95803..b1f5014ff162 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -40,9 +40,6 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 * MXNET_MP_WORKER_NTHREADS - Values: Int ```(default=1)``` - The number of scheduling threads on CPU given to multiprocess workers. Enlarge this number allows more operators to run in parallel in individual workers but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows). -* MXNET_MP_OMP_NUM_THREADS - - Values: Int ```(default=1)``` - - The number of OpenMP threads limit given to multiprocess workers. OpenMP is disabled in worker process if `MXNET_MP_OMP_NUM_THREADS` <= 1 (default). Enlarge this number may boost operator execution performance of individual workers but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows). * MXNET_MP_OPENCV_NUM_THREADS - Values: Int ```(default=0)``` - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows). diff --git a/src/initialize.cc b/src/initialize.cc index 921169a6ebf9..de7edd1b1455 100644 --- a/src/initialize.cc +++ b/src/initialize.cc @@ -59,16 +59,13 @@ class LibraryInitializer { []() { // Conservative thread management for multiprocess workers const size_t mp_worker_threads = dmlc::GetEnv("MXNET_MP_WORKER_NTHREADS", 1); - const size_t mp_omp_threads = dmlc::GetEnv("MXNET_MP_OMP_NUM_THREADS", 1); dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", mp_worker_threads); - dmlc::SetEnv("OMP_NUM_THREADS", mp_omp_threads); + dmlc::SetEnv("OMP_NUM_THREADS", 1); #if MXNET_USE_OPENCV && !__APPLE__ const size_t mp_cv_num_threads = dmlc::GetEnv("MXNET_MP_OPENCV_NUM_THREADS", 0); cv::setNumThreads(mp_cv_num_threads); // disable opencv threading #endif // MXNET_USE_OPENCV - if (mp_omp_threads <= 1) { - engine::OpenMP::Get()->set_enabled(false); - } + engine::OpenMP::Get()->set_enabled(false); Engine::Get()->Start(); }); #endif