diff --git a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h
index 56a036527a56b..2997ce1fe2473 100644
--- a/paddle/fluid/framework/new_executor/nonblocking_threadpool.h
+++ b/paddle/fluid/framework/new_executor/nonblocking_threadpool.h
@@ -73,6 +73,7 @@ class ThreadPoolTempl {
         allow_spinning_(allow_spinning),
         global_steal_partition_(EncodePartition(0, num_threads_)),
         blocked_(0),
+        num_tasks_(0),
         spinning_(0),
         done_(false),
         cancelled_(false),
@@ -143,6 +144,7 @@ class ThreadPoolTempl {
   void AddTaskWithHint(std::function<void()> fn, int start, int limit) {
     Task t = env_.CreateTask(std::move(fn));
     PerThread* pt = GetPerThread();
+    uint64_t num_tasks = num_tasks_.fetch_add(1, std::memory_order_relaxed) + 1;
     if (pt->pool == this) {
       // Worker thread of this pool, push onto the thread's queue.
       Queue& q = thread_data_[pt->thread_id].queue;
@@ -166,8 +168,11 @@ class ThreadPoolTempl {
     // this. We expect that such scenario is prevented by program, that is,
     // this is kept alive while any threads can potentially be in Schedule.
     if (!t.f) {
-      ec_.Notify(false);
+      if (num_tasks > num_threads_ - blocked_.load(std::memory_order_relaxed)) {
+        ec_.Notify(false);
+      }
     } else {
+      num_tasks_.fetch_sub(1, std::memory_order_relaxed);
       env_.ExecuteTask(t);  // Push failed, execute directly.
     }
   }
@@ -263,6 +268,7 @@ class ThreadPoolTempl {
   std::vector<std::vector<unsigned>> all_coprimes_;
   unsigned global_steal_partition_;
   std::atomic<unsigned> blocked_;
+  std::atomic<uint64_t> num_tasks_;
   std::atomic<bool> spinning_;
   std::atomic<bool> done_;
   std::atomic<bool> cancelled_;
@@ -305,6 +311,7 @@ class ThreadPoolTempl {
         }
         if (t.f) {
           env_.ExecuteTask(t);
+          num_tasks_.fetch_sub(1, std::memory_order_relaxed);
         }
       }
     } else {
@@ -315,8 +322,7 @@ class ThreadPoolTempl {
           if (!t.f) {
             t = GlobalSteal();
             if (!t.f) {
-              // Leave one thread spinning. This reduces latency.
-              if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
+              if (allow_spinning_) {
                 for (int i = 0; i < spin_count && !t.f; i++) {
                   if (!cancelled_.load(std::memory_order_relaxed)) {
                     t = GlobalSteal();
@@ -324,7 +330,6 @@ class ThreadPoolTempl {
                     return;
                   }
                 }
-                spinning_ = false;
               }
               if (!t.f) {
                 if (!WaitForWork(waiter, &t)) {
@@ -336,6 +341,7 @@ class ThreadPoolTempl {
         }
         if (t.f) {
           env_.ExecuteTask(t);
+          num_tasks_.fetch_sub(1, std::memory_order_relaxed);
         }
       }
     }
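
In short, the patch does two things: every `AddTaskWithHint` call now bumps an atomic `num_tasks_` counter (decremented again after a task executes, or immediately when the push fails and the task runs inline), and `ec_.Notify(false)` is only issued when the pending-task count exceeds the number of workers that are not blocked in `WaitForWork`; it also drops the "leave one thread spinning" restriction so every idle worker may spin before sleeping. Below is a minimal, self-contained sketch of the wake-up gating idea only, not the Paddle/Eigen code itself: the class `WakeupGate` and its methods are invented for illustration, and it uses a plain mutex/condition variable with default (sequentially consistent) atomics where the patch uses `EventCount` and relaxed ordering.

```cpp
// Sketch of the wake-up gating used in the patch, under the assumptions above.
// All names here (WakeupGate, OnTaskAdded, ...) are illustrative only.
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>

class WakeupGate {
 public:
  explicit WakeupGate(uint64_t num_threads) : num_threads_(num_threads) {}

  // Producer side: called right after a task is pushed onto a queue.
  void OnTaskAdded() {
    uint64_t num_tasks = num_tasks_.fetch_add(1) + 1;
    // Wake a sleeping worker only when the pending tasks outnumber the
    // workers that are currently awake (num_threads_ - blocked_).
    if (num_tasks > num_threads_ - blocked_.load()) {
      std::lock_guard<std::mutex> lock(mu_);
      cv_.notify_one();
    }
  }

  // Worker side: called after a dequeued task has been executed, or when the
  // push failed and the task ran inline on the producer thread.
  void OnTaskDone() { num_tasks_.fetch_sub(1); }

  // Worker side: called when spinning found no work and the thread sleeps.
  void WaitForWork() {
    blocked_.fetch_add(1);
    {
      std::unique_lock<std::mutex> lock(mu_);
      cv_.wait(lock, [this] { return num_tasks_.load() > 0; });
    }
    blocked_.fetch_sub(1);
  }

 private:
  const uint64_t num_threads_;  // Invariant: blocked_ <= num_threads_.
  std::atomic<uint64_t> num_tasks_{0};
  std::atomic<uint64_t> blocked_{0};
  std::mutex mu_;
  std::condition_variable cv_;
};
```

The sketch relies on sequentially consistent atomics so that the producer's counter update and a worker's `blocked_` update cannot both be missed; the patch itself uses `std::memory_order_relaxed` for these counters and leans on the surrounding `EventCount` machinery (Prewait/CommitWait) to avoid lost wake-ups.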