Skip to content

Commit

Permalink
WorkQueue update
Browse files Browse the repository at this point in the history
  • Loading branch information
liutiexing committed Sep 22, 2021
1 parent 4759bc8 commit 14ce793
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions paddle/fluid/framework/new_executor/nonblocking_threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class ThreadPoolTempl {
allow_spinning_(allow_spinning),
global_steal_partition_(EncodePartition(0, num_threads_)),
blocked_(0),
num_tasks_(0),
spinning_(0),
done_(false),
cancelled_(false),
Expand Down Expand Up @@ -143,6 +144,7 @@ class ThreadPoolTempl {
void AddTaskWithHint(std::function<void()> fn, int start, int limit) {
Task t = env_.CreateTask(std::move(fn));
PerThread* pt = GetPerThread();
uint64_t num_tasks = num_tasks_.fetch_add(1, std::memory_order_relaxed) + 1;
if (pt->pool == this) {
// Worker thread of this pool, push onto the thread's queue.
Queue& q = thread_data_[pt->thread_id].queue;
Expand All @@ -166,8 +168,11 @@ class ThreadPoolTempl {
// this. We expect that such scenario is prevented by program, that is,
// this is kept alive while any threads can potentially be in Schedule.
if (!t.f) {
ec_.Notify(false);
if (num_tasks > num_threads_ - blocked_.load(std::memory_order_relaxed)) {
ec_.Notify(false);
}
} else {
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
env_.ExecuteTask(t); // Push failed, execute directly.
}
}
Expand Down Expand Up @@ -263,6 +268,7 @@ class ThreadPoolTempl {
std::vector<std::vector<unsigned>> all_coprimes_;
unsigned global_steal_partition_;
std::atomic<unsigned> blocked_;
std::atomic<uint64_t> num_tasks_;
std::atomic<bool> spinning_;
std::atomic<bool> done_;
std::atomic<bool> cancelled_;
Expand Down Expand Up @@ -305,6 +311,7 @@ class ThreadPoolTempl {
}
if (t.f) {
env_.ExecuteTask(t);
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
}
}
} else {
Expand All @@ -315,16 +322,14 @@ class ThreadPoolTempl {
if (!t.f) {
t = GlobalSteal();
if (!t.f) {
// Leave one thread spinning. This reduces latency.
if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
if (allow_spinning_) {
for (int i = 0; i < spin_count && !t.f; i++) {
if (!cancelled_.load(std::memory_order_relaxed)) {
t = GlobalSteal();
} else {
return;
}
}
spinning_ = false;
}
if (!t.f) {
if (!WaitForWork(waiter, &t)) {
Expand All @@ -336,6 +341,7 @@ class ThreadPoolTempl {
}
if (t.f) {
env_.ExecuteTask(t);
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
}
}
}
Expand Down

0 comments on commit 14ce793

Please sign in to comment.