Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Support multi-threading for Custom Operator (#14363)
Browse files Browse the repository at this point in the history
* Support multi-threading for Custom Operator

* update

* Update custom-inl.h
  • Loading branch information
wkcn authored and nswamy committed Apr 5, 2019
1 parent 2b315c1 commit 207c6c1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
3 changes: 3 additions & 0 deletions docs/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_MP_OPENCV_NUM_THREADS
- Values: Int ```(default=0)```
- The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
* MXNET_CUSTOM_OP_NUM_THREADS
- Values: Int ```(default=16)```
- The maximum number of threads given to custom operators.

## Memory Options

Expand Down
51 changes: 34 additions & 17 deletions src/operator/custom/custom-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <mxnet/operator.h>
#include <mxnet/c_api.h>
#include <mxnet/imperative.h>
#include <algorithm>
#include <map>
#include <vector>
#include <string>
Expand Down Expand Up @@ -129,6 +130,9 @@ class CustomOperator {
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0,
"CustomOperator");
});
// increase num_threads if there is not enough threads to execute custom operator
if (q_.size() > num_free_threads)
CreateThreads(q_.size() - num_free_threads);
cv_.notify_all();
}

Expand All @@ -139,38 +143,51 @@ class CustomOperator {
destructing_ = true;
cv_.notify_all();
}
worker_.join();
for (auto &worker : workers_)
worker.join();
}

static CustomOperator* Get();

private:
CustomOperator() {
CustomOperator() : num_free_threads(0) {
destructing_ = false;
naive_engine_ = true;
if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
naive_engine_ = false;
worker_ = std::thread(
[&]() {
std::unique_lock<std::mutex> lock(mutex_);
while (!q_.empty() || !destructing_) {
cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
while (!q_.empty()) {
auto fn = q_.front();
lock.unlock();
fn();
lock.lock();
q_.pop();
}
}
});
}
}
void ThreadTarget() {
std::unique_lock<std::mutex> lock(mutex_);
while (!q_.empty() || !destructing_) {
cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
while (!q_.empty()) {
--num_free_threads;
auto fn = q_.front();
q_.pop();
lock.unlock();
fn();
++num_free_threads;
lock.lock();
}
}
}
void SetNumThreads(int num_threads) {
num_threads = std::min(dmlc::GetEnv("MXNET_CUSTOM_OP_NUM_THREADS", 16), num_threads);
for (int i = workers_.size(); i < num_threads; ++i) {
workers_.emplace_back(std::thread([this]{this->ThreadTarget();}));
++num_free_threads;
}
}
void CreateThreads(int num_new_threads) {
SetNumThreads(workers_.size() + num_new_threads);
}
std::mutex mutex_;
std::map<std::string, CustomOpPropCreator> registry_;
// async worker
std::condition_variable cv_;
std::thread worker_;
std::vector<std::thread> workers_;
std::atomic<uint32_t> num_free_threads;
std::queue<std::function<void(void)> > q_;
bool naive_engine_;
bool destructing_;
Expand Down

0 comments on commit 207c6c1

Please sign in to comment.