From 207c6c1dddaa19587b8809e17475e9707e946c97 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Fri, 15 Mar 2019 10:16:25 +0800 Subject: [PATCH] Support multi-threading for Custom Operator (#14363) * Support multi-threading for Custom Operator * update * Update custom-inl.h --- docs/faq/env_var.md | 3 ++ src/operator/custom/custom-inl.h | 51 +++++++++++++++++++++----------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 095c214e66b3..2768f644c066 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -60,6 +60,9 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 * MXNET_MP_OPENCV_NUM_THREADS - Values: Int ```(default=0)``` - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows). +* MXNET_CUSTOM_OP_NUM_THREADS + - Values: Int ```(default=16)``` + - The maximum number of threads given to custom operators. ## Memory Options diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h index de82f3e8eef0..f88e830bc573 100644 --- a/src/operator/custom/custom-inl.h +++ b/src/operator/custom/custom-inl.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -129,6 +130,9 @@ class CustomOperator { ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0, "CustomOperator"); }); + // increase num_threads if there is not enough threads to execute custom operator + if (q_.size() > num_free_threads) + CreateThreads(q_.size() - num_free_threads); cv_.notify_all(); } @@ -139,38 +143,51 @@ class CustomOperator { destructing_ = true; cv_.notify_all(); } - worker_.join(); + for (auto &worker : workers_) + worker.join(); } static CustomOperator* Get(); private: - CustomOperator() { + CustomOperator() : num_free_threads(0) { destructing_ = false; naive_engine_ = true; if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) { naive_engine_ = false; - worker_ = std::thread( - [&]() { - std::unique_lock lock(mutex_); - while (!q_.empty() || !destructing_) { - cv_.wait(lock, [&] {return !q_.empty() || destructing_;}); - while (!q_.empty()) { - auto fn = q_.front(); - lock.unlock(); - fn(); - lock.lock(); - q_.pop(); - } - } - }); } } + void ThreadTarget() { + std::unique_lock lock(mutex_); + while (!q_.empty() || !destructing_) { + cv_.wait(lock, [&] {return !q_.empty() || destructing_;}); + while (!q_.empty()) { + --num_free_threads; + auto fn = q_.front(); + q_.pop(); + lock.unlock(); + fn(); + ++num_free_threads; + lock.lock(); + } + } + } + void SetNumThreads(int num_threads) { + num_threads = std::min(dmlc::GetEnv("MXNET_CUSTOM_OP_NUM_THREADS", 16), num_threads); + for (int i = workers_.size(); i < num_threads; ++i) { + workers_.emplace_back(std::thread([this]{this->ThreadTarget();})); + ++num_free_threads; + } + } + void CreateThreads(int num_new_threads) { + SetNumThreads(workers_.size() + num_new_threads); + } std::mutex mutex_; std::map registry_; // async worker std::condition_variable cv_; - std::thread worker_; + std::vector workers_; + std::atomic num_free_threads; std::queue > q_; bool naive_engine_; bool destructing_;