diff --git a/include/onnxruntime/core/platform/Barrier.h b/include/onnxruntime/core/platform/Barrier.h index e44dacaa9c219..67cc29e3a3475 100644 --- a/include/onnxruntime/core/platform/Barrier.h +++ b/include/onnxruntime/core/platform/Barrier.h @@ -25,13 +25,14 @@ class Barrier { } #endif - void Notify() { - unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2; + void Notify(unsigned int c = 1) { + unsigned int delta = c << 1; + unsigned int v = state_.fetch_sub(delta, std::memory_order_acq_rel) - delta; if (v != 1) { // Clear the lowest bit (waiter flag) and check that the original state // value was not zero. If it was zero, it means that notify was called // more times than the original count. - assert(((v + 2) & ~1) != 0); + assert(((v + delta) & ~1) != 0); return; // either count has not dropped to 0, or waiter is not waiting } std::unique_lock l(mu_); diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index 075a5c1cac6bd..56ac186c377ba 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -30,214 +30,30 @@ #elif defined(_MSC_VER) #pragma warning(pop) #endif +#include "core/common/make_unique.h" #include "core/platform/ort_mutex.h" #include "core/platform/Barrier.h" namespace onnxruntime { +namespace concurrency { -class EventCount { - public: - class Waiter; - - explicit EventCount(Eigen::MaxSizeVector& waiters) : state_(kStackMask), waiters_(waiters) { - assert(waiters.size() < (1 << kWaiterBits) - 1); - } - -#ifdef NDEBUG - ~EventCount() = default; -#else - ~EventCount() { - // Ensure there are no waiters. - assert(state_.load() == kStackMask); - } -#endif - // Prewait prepares for waiting. - // After calling Prewait, the thread must re-check the wait predicate - // and then call either CancelWait or CommitWait. - void Prewait() { - uint64_t state = state_.load(std::memory_order_relaxed); - for (;;) { - CheckState(state); - uint64_t newstate = state + kWaiterInc; - CheckState(newstate); - if (state_.compare_exchange_weak(state, newstate, std::memory_order_seq_cst)) - return; - } - } - - // CommitWait commits waiting after Prewait. - void CommitWait(Waiter* w) { - assert((w->epoch & ~kEpochMask) == 0); - w->state = Waiter::kNotSignaled; - const uint64_t me = (w - &waiters_[0]) | w->epoch; - uint64_t state = state_.load(std::memory_order_seq_cst); - for (;;) { - CheckState(state, true); - uint64_t newstate; - if ((state & kSignalMask) != 0) { - // Consume the signal and return immidiately. - newstate = state - kWaiterInc - kSignalInc; - } else { - // Remove this thread from pre-wait counter and add to the waiter stack. - newstate = ((state & kWaiterMask) - kWaiterInc) | me; - w->next.store(state & (kStackMask | kEpochMask), std::memory_order_relaxed); - } - CheckState(newstate); - if (state_.compare_exchange_weak(state, newstate, std::memory_order_acq_rel)) { - if ((state & kSignalMask) == 0) { - w->epoch += kEpochInc; - Park(w); - } - return; - } - } - } - - // CancelWait cancels effects of the previous Prewait call. - void CancelWait() { - uint64_t state = state_.load(std::memory_order_relaxed); - for (;;) { - CheckState(state, true); - uint64_t newstate = state - kWaiterInc; - // We don't know if the thread was also notified or not, - // so we should not consume a signal unconditionaly. 
- // Only if number of waiters is equal to number of signals, - // we know that the thread was notified and we must take away the signal. - if (((state & kWaiterMask) >> kWaiterShift) == ((state & kSignalMask) >> kSignalShift)) - newstate -= kSignalInc; - CheckState(newstate); - if (state_.compare_exchange_weak(state, newstate, std::memory_order_acq_rel)) - return; - } - } - - // Notify wakes one or all waiting threads. - // Must be called after changing the associated wait predicate. - void Notify(bool notifyAll) { - std::atomic_thread_fence(std::memory_order_seq_cst); - uint64_t state = state_.load(std::memory_order_acquire); - for (;;) { - CheckState(state); - const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift; - const uint64_t signals = (state & kSignalMask) >> kSignalShift; - // Easy case: no waiters. - if ((state & kStackMask) == kStackMask && waiters == signals) - return; - uint64_t newstate; - if (notifyAll) { - // Empty wait stack and set signal to number of pre-wait threads. - newstate = (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask; - } else if (signals < waiters) { - // There is a thread in pre-wait state, unblock it. - newstate = state + kSignalInc; - } else { - // Pop a waiter from list and unpark it. - Waiter* w = &waiters_[state & kStackMask]; - uint64_t next = w->next.load(std::memory_order_relaxed); - newstate = (state & (kWaiterMask | kSignalMask)) | next; - } - CheckState(newstate); - if (state_.compare_exchange_weak(state, newstate, std::memory_order_acq_rel)) { - if (!notifyAll && (signals < waiters)) - return; // unblocked pre-wait thread - if ((state & kStackMask) == kStackMask) - return; - Waiter* w = &waiters_[state & kStackMask]; - if (!notifyAll) - w->next.store(kStackMask, std::memory_order_relaxed); - Unpark(w); - return; - } - } - } - - class Waiter { - friend class EventCount; - // Align to 128 byte boundary to prevent false sharing with other Waiter - // objects in the same vector. - EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic next; - OrtMutex mu; - OrtCondVar cv; - uint64_t epoch = 0; - unsigned state = kNotSignaled; - enum { - kNotSignaled, - kWaiting, - kSignaled, - }; - }; - - private: - // State_ layout: - // - low kWaiterBits is a stack of waiters committed wait - // (indexes in waiters_ array are used as stack elements, - // kStackMask means empty stack). - // - next kWaiterBits is count of waiters in prewait state. - // - next kWaiterBits is count of pending signals. - // - remaining bits are ABA counter for the stack. - // (stored in Waiter node and incremented on push). 
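Reviewer note: for readers tracing this deletion, the EventCount being removed packs its whole synchronization state into a single 64-bit word, as the layout comment above describes. A standalone sketch of that encoding (field widths taken from the kWaiterBits constants that follow; DecodedState is illustrative and not part of the ORT sources):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative decoder for the removed EventCount state word.
struct DecodedState {
  uint64_t stack_top;   // index of first committed waiter, or kStackMask if empty
  uint64_t prewaiters;  // count of threads between Prewait() and CommitWait()
  uint64_t signals;     // count of pending signals
  uint64_t epoch;       // ABA counter for the waiter stack
};

constexpr uint64_t kWaiterBits = 14;
constexpr uint64_t kStackMask = (1ull << kWaiterBits) - 1;

DecodedState Decode(uint64_t state) {
  return DecodedState{
      state & kStackMask,                         // low 14 bits: waiter stack top
      (state >> kWaiterBits) & kStackMask,        // next 14 bits: pre-wait count
      (state >> (2 * kWaiterBits)) & kStackMask,  // next 14 bits: signal count
      state >> (3 * kWaiterBits)};                // remaining 22 bits: epoch
}

int main() {
  DecodedState d = Decode(kStackMask);  // initial state: empty stack, no waiters
  std::printf("stack=%llx prewait=%llu signals=%llu epoch=%llu\n",
              (unsigned long long)d.stack_top, (unsigned long long)d.prewaiters,
              (unsigned long long)d.signals, (unsigned long long)d.epoch);
}
```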
- static constexpr uint64_t kWaiterBits = 14; - static constexpr uint64_t kStackMask = (1ull << kWaiterBits) - 1; - static constexpr uint64_t kWaiterShift = kWaiterBits; - static constexpr uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1) << kWaiterShift; - static constexpr uint64_t kWaiterInc = 1ull << kWaiterShift; - static constexpr uint64_t kSignalShift = 2 * kWaiterBits; - static constexpr uint64_t kSignalMask = ((1ull << kWaiterBits) - 1) << kSignalShift; - static constexpr uint64_t kSignalInc = 1ull << kSignalShift; - static constexpr uint64_t kEpochShift = 3 * kWaiterBits; - static constexpr uint64_t kEpochBits = 64 - kEpochShift; - static constexpr uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift; - static constexpr uint64_t kEpochInc = 1ull << kEpochShift; - std::atomic state_; - Eigen::MaxSizeVector& waiters_; - -#ifdef NDEBUG - static void CheckState(uint64_t, bool) { - } - static void CheckState(uint64_t) { - } -#else - static void CheckState(uint64_t state, bool waiter = false) { - static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem"); - const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift; - const uint64_t signals = (state & kSignalMask) >> kSignalShift; - assert(waiters >= signals); - assert(waiters < (1 << kWaiterBits) - 1); - assert(!waiter || waiters > 0); - (void)waiters; - (void)signals; - } -#endif - static void Park(Waiter* w) { - std::unique_lock lock(w->mu); - while (w->state != Waiter::kSignaled) { - w->state = Waiter::kWaiting; - w->cv.wait(lock); - } - } - - void Unpark(Waiter* w) { - for (Waiter* next; w; w = next) { - uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask; - next = wnext == kStackMask ? nullptr : &waiters_[static_cast(wnext)]; - unsigned state; - { - std::unique_lock lock(w->mu); - state = w->state; - w->state = Waiter::kSignaled; - } - // Avoid notifying if it wasn't waiting. - if (state == Waiter::kWaiting) - w->cv.notify_one(); - } - } +// Extended Eigen thread pool interface, avoiding the need to modify the ThreadPoolInterface.h +// header from the external Eigen repository. +class ExtendedThreadPoolInterface : public Eigen::ThreadPoolInterface { public: - EventCount(const EventCount&) = delete; - void operator=(const EventCount&) = delete; + // Run fn with up to n degree-of-parallelism enlisting the thread pool for + // help. The degree-of-parallelism includes the caller, and so if n==1 + // then the function will run directly in the caller. The fork-join + // synchronization is handled in the thread pool, and so any state captured + // by fn() is safe from concurrent access once RunInParallel returns. + virtual void RunInParallel(std::function fn, unsigned n) = 0; }; -template + +} // namespace concurrency + +template class RunQueue { public: RunQueue() : front_(0), back_(0) { @@ -245,7 +61,7 @@ class RunQueue { assert((kSize & (kSize - 1)) == 0); assert(kSize > 2); // why would you do this? assert(kSize <= (64 << 10)); // leave enough space for counter - for (unsigned i = 0; i < kSize; i++) array_[i].state.store(kEmpty, std::memory_order_relaxed); + for (unsigned i = 0; i < kSize; i++) array_[i].state.store(ElemState::kEmpty, std::memory_order_relaxed); } ~RunQueue() { @@ -256,26 +72,47 @@ class RunQueue { // If queue is full returns w, otherwise returns default-constructed Work. 
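Reviewer note: a sketch of how a caller might drive the RunInParallel entry point declared on ExtendedThreadPoolInterface above. The function may be invoked concurrently by up to n threads (the caller included), so captured state must be thread-safe, but it is guaranteed quiescent once the call returns. ParallelSum and its block size are hypothetical, not part of the patch:

```cpp
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical caller of ExtendedThreadPoolInterface::RunInParallel.
int64_t ParallelSum(onnxruntime::concurrency::ExtendedThreadPoolInterface& pool,
                    const std::vector<int32_t>& v, unsigned n) {
  constexpr size_t kBlock = 1024;
  std::atomic<size_t> next{0};    // shared cursor; the lambda may run on up to n threads
  std::atomic<int64_t> total{0};  // fork-join result accumulator
  pool.RunInParallel([&]() {
    int64_t local = 0;
    for (;;) {
      size_t start = next.fetch_add(kBlock, std::memory_order_relaxed);
      if (start >= v.size()) break;
      size_t end = std::min(start + kBlock, v.size());
      for (size_t i = start; i < end; i++) local += v[i];
    }
    total.fetch_add(local, std::memory_order_relaxed);
  }, n);
  // All invocations of the lambda have completed here, so next and total
  // can be read and destroyed safely.
  return total.load(std::memory_order_relaxed);
}
```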
Work PushFront(Work w) { unsigned front = front_.load(std::memory_order_relaxed); - Elem* e = &array_[front & kMask]; - uint8_t s = e->state.load(std::memory_order_relaxed); - if (s != kEmpty || !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire)) + Elem& e = array_[front & kMask]; + ElemState s = e.state.load(std::memory_order_relaxed); + if (s != ElemState::kEmpty || + !e.state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) return w; front_.store(front + 1 + (kSize << 1), std::memory_order_relaxed); - e->w = std::move(w); - e->state.store(kReady, std::memory_order_release); + e.w = std::move(w); + e.tag = Tag(); + e.state.store(ElemState::kReady, std::memory_order_release); return Work(); } // PopFront removes and returns the first element in the queue. // If the queue was empty returns default-constructed Work. Work PopFront() { - unsigned front = front_.load(std::memory_order_relaxed); - Elem* e = &array_[(front - 1) & kMask]; - uint8_t s = e->state.load(std::memory_order_relaxed); - if (s != kReady || !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire)) + unsigned front; + Elem *e; + ElemState s; + + // Drain revoked items from the front of the queue. CAS to busy to synchronize with + // any attempt to take the same item from the back of the queue. + do { + front = front_.load(std::memory_order_relaxed); + e = &array_[(front - 1) & kMask]; + s = e->state.load(std::memory_order_relaxed); + if (s == ElemState::kRevoked && + e->state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) { + e->state.store(ElemState::kEmpty, std::memory_order_release); + front = ((front - 1) & kMask2) | (front & ~kMask2); + front_.store(front, std::memory_order_relaxed); + } + } while (s == ElemState::kRevoked); + + // Attempt to take next item. State kEmpty shows the queue is empty, kBusy shows + // the work is in progress on the item at the front of the queue. + if (s != ElemState::kReady || + !e->state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) return Work(); Work w = std::move(e->w); - e->state.store(kEmpty, std::memory_order_release); + e->tag = Tag(); + e->state.store(ElemState::kEmpty, std::memory_order_release); front = ((front - 1) & kMask2) | (front & ~kMask2); front_.store(front, std::memory_order_relaxed); return w; @@ -286,14 +123,39 @@ class RunQueue { Work PushBack(Work w) { std::unique_lock lock(mutex_); unsigned back = back_.load(std::memory_order_relaxed); - Elem* e = &array_[(back - 1) & kMask]; - uint8_t s = e->state.load(std::memory_order_relaxed); - if (s != kEmpty || !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire)) + Elem& e = array_[(back - 1) & kMask]; + ElemState s = e.state.load(std::memory_order_relaxed); + if (s != ElemState::kEmpty || + !e.state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) + return w; + back = ((back - 1) & kMask2) | (back & ~kMask2); + back_.store(back, std::memory_order_relaxed); + e.w = std::move(w); + e.tag = Tag(); + e.state.store(ElemState::kReady, std::memory_order_release); + return Work(); + } + + // PushBackWithTag adds w at the end of the queue. The tag value can be used on a + // subsequent call to RevokeWithTag to remove the item from the queue in combination + // with w_idx. Typically the tag will be a per-thread ID to distinguish work + // submitted from different threads. + // + // If the queue is full, returns w, otherwise returns default-constructed work. 
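Reviewer note: the PushFront/PopFront/PushBack bodies above all manipulate front_/back_ with the same packed encoding: the low log2(kSize)+1 bits are a ring position and the bits above them form a modification counter read by Size(). A standalone model of the two update patterns (kSize chosen arbitrarily for illustration):

```cpp
#include <cassert>

// Standalone model of the RunQueue index encoding.
constexpr unsigned kSize = 8;                  // must be a power of two
constexpr unsigned kMask = kSize - 1;          // slot index within array_
constexpr unsigned kMask2 = (kSize << 1) - 1;  // position including wrap bit

unsigned Advance(unsigned idx) {
  // Used by PushFront/PopBack: +1 to the position, +1 to the counter.
  return idx + 1 + (kSize << 1);
}

unsigned Retreat(unsigned idx) {
  // Used by PopFront/PushBack: -1 to the position, counter preserved.
  return ((idx - 1) & kMask2) | (idx & ~kMask2);
}

int main() {
  unsigned front = 0;
  front = Advance(front);
  assert((front & kMask) == 1);              // slot moved forward
  assert((front & ~kMask2) == (kSize << 1)); // counter bumped once
  front = Retreat(front);
  assert((front & kMask2) == 0);             // position back where it started
  assert((front & ~kMask2) == (kSize << 1)); // counter still records the change
  return 0;
}
```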
+  Work PushBackWithTag(Work w, Tag tag, unsigned& w_idx) {
+    std::unique_lock<OrtMutex> lock(mutex_);
+    unsigned back = back_.load(std::memory_order_relaxed);
+    w_idx = (back - 1) & kMask;
+    Elem& e = array_[w_idx];
+    ElemState s = e.state.load(std::memory_order_relaxed);
+    if (s != ElemState::kEmpty ||
+        !e.state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire))
       return w;
     back = ((back - 1) & kMask2) | (back & ~kMask2);
     back_.store(back, std::memory_order_relaxed);
-    e->w = std::move(w);
-    e->state.store(kReady, std::memory_order_release);
+    e.w = std::move(w);
+    e.tag = tag;
+    e.state.store(ElemState::kReady, std::memory_order_release);
     return Work();
   }
@@ -302,49 +164,76 @@ class RunQueue {
     if (Empty())
       return Work();
     std::unique_lock<OrtMutex> lock(mutex_);
-    unsigned back = back_.load(std::memory_order_relaxed);
-    Elem* e = &array_[back & kMask];
-    uint8_t s = e->state.load(std::memory_order_relaxed);
-    if (s != kReady || !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire))
+    unsigned back;
+    Elem* e;
+    ElemState s;
+
+    // Drain revoked items from the back of the queue. CAS to busy to synchronize with
+    // any attempt to take the same item from the front of the queue.
+    do {
+      back = back_.load(std::memory_order_relaxed);
+      e = &array_[back & kMask];
+      s = e->state.load(std::memory_order_relaxed);
+      if (s == ElemState::kRevoked &&
+          e->state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) {
+        e->state.store(ElemState::kEmpty, std::memory_order_release);
+        back_.store(back + 1 + (kSize << 1), std::memory_order_relaxed);
+      }
+    } while (s == ElemState::kRevoked);
+
+    if (s != ElemState::kReady ||
+        !e->state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire))
       return Work();
     Work w = std::move(e->w);
-    e->state.store(kEmpty, std::memory_order_release);
+    e->tag = Tag();
+    e->state.store(ElemState::kEmpty, std::memory_order_release);
     back_.store(back + 1 + (kSize << 1), std::memory_order_relaxed);
     return w;
   }

-  // PopBackHalf removes and returns half last elements in the queue.
-  // Returns number of elements removed.
-  unsigned PopBackHalf(std::vector<Work>* result) {
-    if (Empty())
-      return 0;
+  // RevokeWithTag removes a work item from the queue. Items are identified positionally,
+  // and so a tag is used to detect whether the same position is occupied by a
+  // different work item at the time of removal. RevokeWithTag lets threads offer work
+  // for parallel execution, and then revoke the offer prior to the work executing (for
+  // instance if the thread itself completes all of the work). Revoking the work
+  // lets the thread deallocate state that might otherwise have been captured by the work item
+  // and accessed by it.
+  //
+  // Returns true iff the item is successfully revoked. If the item is not revoked then
+  // the caller must assume that it may still execute, for instance because it
+  // has been popped from the queue concurrently with the revocation request.
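Reviewer note: the protocol described in the comment above may be easier to see from the dispatcher's side. The sketch below is illustrative only: it assumes the RunQueue template parameters are <Work, Tag, kSize> as used elsewhere in this patch, and substitutes a simple std::function-based work item for the pool's Environment::Task.

```cpp
#include <cstdint>
#include <functional>

// Hypothetical work item; the pool's real item is Environment::Task, and its
// tag is the Tag struct defined later in this header (a plain integer here).
struct Work {
  std::function<void()> f;  // empty f means "no work", mirroring Task
};

// q and my_tag are assumed to come from the surrounding pool state.
void OfferAndMaybeRevoke(onnxruntime::RunQueue<Work, uint32_t, 1024>& q, uint32_t my_tag) {
  unsigned w_idx;
  Work w{[] { /* help out with a parallel section */ }};
  w = q.PushBackWithTag(std::move(w), my_tag, w_idx);
  if (w.f) {
    w.f();  // queue full: the offer was rejected, so run the work inline
    return;
  }
  // ... perform this thread's own share of the loop ...
  if (q.RevokeWithTag(my_tag, w_idx)) {
    // Revoked before any worker claimed it: the item will never run, and any
    // state it captured can be freed immediately.
  } else {
    // A worker popped it (or is about to): we must synchronize with its
    // completion; RunInParallel uses a Barrier notification for this.
  }
}
```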
+ + bool RevokeWithTag(Tag tag, unsigned w_idx) { + bool revoked = false; std::unique_lock lock(mutex_); - unsigned back = back_.load(std::memory_order_relaxed); - unsigned size = Size(); - unsigned mid = back; - if (size > 1) - mid = back + (size - 1) / 2; - unsigned n = 0; - unsigned start = 0; - for (; static_cast(mid - back) >= 0; mid--) { - Elem* e = &array_[mid & kMask]; - uint8_t s = e->state.load(std::memory_order_relaxed); - if (n == 0) { - if (s != kReady || !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire)) - continue; - start = mid; + Elem& e = array_[w_idx]; + ElemState s = e.state.load(std::memory_order_relaxed); + if (s == ElemState::kReady && + e.state.compare_exchange_strong(s, ElemState::kBusy, std::memory_order_acquire)) { + if (e.tag == tag) { + unsigned back = back_.load(std::memory_order_relaxed); + unsigned back_idx = back & kMask; + if (back_idx != w_idx) { + // Item is not at the back of the queue, mark it in-place as revoked + e.tag = Tag(); + e.w = Work(); + e.state.store(ElemState::kRevoked, std::memory_order_release); + revoked = true; + } else { + // Item being removed as still at the back; shift the back pointer over it, + // and bump the version number. + e.tag = Tag(); + e.w = Work(); + e.state.store(ElemState::kEmpty, std::memory_order_release); + back_.store(back + 1 + (kSize << 1), std::memory_order_relaxed); + revoked = true; + } } else { - // Note: no need to store temporal kBusy, we exclusively own these - // elements. - assert(s == kReady); + // Tag mismatch, i.e. work queue slot re-used + e.state.store(ElemState::kReady, std::memory_order_release); } - result->push_back(std::move(e->w)); - e->state.store(kEmpty, std::memory_order_release); - n++; } - if (n != 0) - back_.store(start + 1 + (kSize << 1), std::memory_order_relaxed); - return n; + return revoked; } // Size returns current queue size. @@ -369,15 +258,27 @@ class RunQueue { private: static const unsigned kMask = kSize - 1; static const unsigned kMask2 = (kSize << 1) - 1; - struct Elem { - std::atomic state; - Work w; - }; - enum { + + enum class ElemState : uint8_t { kEmpty, kBusy, kReady, + kRevoked, }; + + // Updates to an element are bracketed by a std::memory_order_acquire + // load from the state, and a std::memory_order_release store. Accesses + // to the front/back indices for the work queue use relaxed semantics, + // with the state of the elements being authoritative. + // + // TODO: Revisit whether there is a significant benefit for the current + // workloads in the complexity here. + struct Elem { + std::atomic state; + Tag tag; + Work w; + }; + OrtMutex mutex_; // Low log(kSize) + 1 bits in front_ and back_ contain rolling index of // front/back, respectively. The remaining bits contain modification counters @@ -438,8 +339,11 @@ class RunQueue { void operator=(const RunQueue&) = delete; }; +static std::atomic next_tag{1}; + template -class ThreadPoolTempl : public Eigen::ThreadPoolInterface { +class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInterface { + private: static unsigned WorkerLoop(int id, Eigen::ThreadPoolInterface* param) { // unsafe downcast @@ -450,7 +354,45 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { public: typedef typename Environment::Task Task; - typedef RunQueue Queue; + + struct Tag { + constexpr Tag() : v_(0) { + } + + Tag(uint32_t v) : v_(v) { + } + + // Allocate a new tag to use to identify work items from a given thread + // in RunInParallel. 
Ideally, threads will have unique tags, but re-use + // is not incorrect if the counter wraps (for intsance, if a long-running + // workload is calling into ORT from a fresh thread for each request). + // We must not re-use the default tag 0 which is used to identify work + // items added via Schedule as opposed to requests for help in RunInParallel. + + static Tag GetNext() { + Tag t = Tag(next_tag++); + if (t.v_ == 0) { + t = Tag(next_tag++); + } + return t; + } + + uint32_t Get() const { + return v_; + } + + bool operator==(Tag& other) const { + return v_ == other.v_; + } + + uint32_t v_ = 0; + }; + + static Tag GetNextTag() { + return Tag(next_tag++); + } + + typedef RunQueue Queue; #ifdef _WIN32 using CHAR_TYPE = wchar_t; #else @@ -463,14 +405,10 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { allow_spinning_(allow_spinning), thread_data_(num_threads), all_coprimes_(num_threads), - waiters_(num_threads), - global_steal_partition_(EncodePartition(0, num_threads_)), blocked_(0), - spinning_(false), done_(false), - cancelled_(false), - ec_(waiters_) { - waiters_.resize(num_threads_); + cancelled_(false) { + // Calculate coprimes of all numbers [1, num_threads]. // Coprimes are used for random walks over all threads in Steal // and NonEmptyQueueIndex. Iteration is based on the fact that if we take @@ -478,15 +416,19 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { // indices as (t + coprime) % num_threads, we will cover all threads without // repetitions (effectively getting a presudo-random permutation of thread // indices). - assert(num_threads_ < kMaxThreads); for (int i = 1; i <= num_threads_; ++i) { all_coprimes_.emplace_back(i); ComputeCoprimes(i, &all_coprimes_.back()); } + // Allocate space for per-thread bits to indicate which threads to consider + // preferable for pushing work. We use a regular array given that a std::vector + // cannot contain std::atomic. + num_hint_words_ = static_cast((num_threads_ + bits_per_hint_word_ - 1) / bits_per_hint_word_); + good_worker_hints_ = onnxruntime::make_unique[]>(num_hint_words_); + thread_data_.resize(num_threads_); for (int i = 0; i < num_threads_; i++) { - SetStealPartition(i, EncodePartition(0, num_threads_)); thread_data_[i].thread.reset(env_.CreateThread(name, i, WorkerLoop, this, thread_options)); } } @@ -498,7 +440,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { // But note that threads can continue to work arbitrary long, // block, submit new work, unblock and otherwise live full life. if (!cancelled_) { - ec_.Notify(true); + WakeAllWorkersForExit(); } else { // Since we were cancelled, there might be entries in the queues. // Empty them to prevent their destructor from asserting. @@ -511,25 +453,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { for (size_t i = 0; i < thread_data_.size(); ++i) thread_data_[i].thread.reset(); } - void SetStealPartitions(const std::vector>& partitions) { - assert(partitions.size() == static_cast(num_threads_)); - - // Pass this information to each thread queue. 
- for (int i = 0; i < num_threads_; i++) { - const auto& pair = partitions[i]; - unsigned start = pair.first; - unsigned end = pair.second; - AssertBounds(start, end); - unsigned val = EncodePartition(start, end); - SetStealPartition(i, val); - } - } - void Schedule(std::function fn) override { - ScheduleWithHint(std::move(fn), 0, num_threads_); - } - - void ScheduleWithHint(std::function fn, int start, int limit) override { Task t = env_.CreateTask(std::move(fn)); PerThread* pt = GetPerThread(); if (pt->pool == this) { @@ -539,77 +463,187 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { } else { // A free-standing thread (or worker of another pool), push onto a random // queue. - assert(start < limit); - assert(limit <= num_threads_); - int num_queues = limit - start; - int rnd = Rand(&pt->rand) % num_queues; - assert(start + rnd < limit); - Queue& q = thread_data_[start + rnd].queue; + int q_idx = Rand(&pt->rand) % num_threads_; + ThreadData &td = thread_data_[q_idx]; + Queue& q = td.queue; t = q.PushBack(std::move(t)); - } - // Note: below we touch this after making w available to worker threads. - // Strictly speaking, this can lead to a racy-use-after-free. Consider that - // Schedule is called from a thread that is neither main thread nor a worker - // thread of this pool. Then, execution of w directly or indirectly - // completes overall computations, which in turn leads to destruction of - // this. We expect that such scenario is prevented by program, that is, - // this is kept alive while any threads can potentially be in Schedule. - if (!t.f) { - ec_.Notify(false); - } else { - env_.ExecuteTask(t); // Push failed, execute directly. + if (t.f) { + // The queue rejected the work; run it directly + env_.ExecuteTask(t); + } else { + // The queue accepted the work; ensure that the thread will pick it up + td.EnsureAwake(); + } } } - void Cancel() override { - cancelled_ = true; - // If done_ is true, which means this object is being destructing. - // Therefore thread_data_[i].thread could be NULL. - if (!done_) { - done_ = true; - // Let each thread know it's been cancelled. - for (size_t i = 0; i < thread_data_.size(); i++) { - assert(thread_data_[i].thread != nullptr); - thread_data_[i].thread->OnCancel(); +// The thread pool maintains a set of hints for which threads will be good to distribute +// work to. A thread is considered "good" if it is actively spinning, meaning both that +// it is not busy with existing work, and that it should respond quickly to the addition +// of new work. + +void SetGoodWorkerHint(int idx, bool is_good) { + assert(idx >= 0 && idx < num_threads_); + std::atomic& u64 = good_worker_hints_[idx / bits_per_hint_word_]; + uint64_t bit = 1ull << (idx % bits_per_hint_word_); + uint64_t saw, want; + do { + saw = u64.load(); + want = is_good ? (saw|bit) : (saw&~bit); + } while (!u64.compare_exchange_weak(saw, want)); +} + +// Retrieve hints for up to n threads to distribute work to. Threads in good_hints +// pass a best-effort check to identify spinning threads via the good_worker_hints_ +// bitmap. Threads in alt_hint do not pass that test, but are distinct from those in +// good_hints, letting the caller avoid distributing more than one work item to +// any individual thread. + +void GetGoodWorkerHints(int n, std::vector& good_hints, std::vector& alt_hints) { + PerThread* pt = GetPerThread(); + int need_alt = n; + good_hints.clear(); + alt_hints.clear(); + + // Iterate through the words of hints, starting from a pseudo-randomly chosen + // base. 
This aims to distribute work across large machines in cases where we
+  // have multiple threads scheduling work concurrently.
+
+  unsigned base = Rand(&pt->rand) % num_hint_words_;
+  for (int i = 0; n && (i < num_hint_words_); i++) {
+    int u64_idx = (base + i) % num_hint_words_;
+    std::atomic<uint64_t>* u64 = &good_worker_hints_[u64_idx];
+    uint64_t saw = u64->load();
+    uint64_t want = saw;
+
+    // Pick up to n bits that are set in the current word
+    for (int j = 0; n && (j < bits_per_hint_word_); j++) {
+      uint64_t bit = 1ull << j;
+      int thread = u64_idx * bits_per_hint_word_ + j;
+      if (saw & bit) {
+        good_hints.push_back(thread);
+        want &= ~bit;
+        n--;
+      } else if (need_alt && thread < num_threads_) {
+        alt_hints.push_back(thread);
+        need_alt--;
+      }
+    }
-    // Wake up the threads without work to let them exit on their own.
-    ec_.Notify(true);
-  }
+    // Best-effort attempt to remove the hints. We should measure the impact of
+    // contention here, but the intuition is that if we conflict on the CAS then the
+    // machine is likely to be busy in any case, and we will have queuing on the
+    // work items.
+    u64->compare_exchange_strong(saw, want);
+  }
+}
+
+void RunInParallel(std::function<void()> fn, unsigned n) override {
+  PerThread* pt = GetPerThread();
+  assert(n >= 1);
+  if (n == 1 || pt->in_parallel) {
+    fn();
+  } else {
+    // We build a list of (queue index, work index) pairs for each of the queues that accepts a work
+    // item. This lets us remove any work items that do not get executed by the threads
+    // that we push them to.
+    std::vector<std::pair<int, unsigned>> pending_items;
+    Barrier b(n);
+
+    pt->in_parallel = true;
+    if (!pt->tag.Get()) {
+      pt->tag = Tag::GetNext();
+    }
+
+    // Push up to n-1 copies of the work item into the queues
+    std::vector<int>& good_hints = pt->good_hints;
+    std::vector<int>& alt_hints = pt->alt_hints;
+    GetGoodWorkerHints(n - 1, good_hints, alt_hints);
+    for (unsigned i = 0; i < n - 1; i++) {
+      Task t = env_.CreateTask([&b, &fn]() {
+        fn();
+        b.Notify(1);
+      });
+      int q_idx;
+      if (i < good_hints.size()) {
+        q_idx = good_hints[i];
+      } else {
+        auto alt_i = i - static_cast<unsigned>(good_hints.size());
+        if (alt_i < alt_hints.size()) {
+          q_idx = alt_hints[alt_i];
+        } else {
+          q_idx = Rand(&pt->rand) % num_threads_;
+        }
+      }
+      ThreadData& td = thread_data_[q_idx];
+      Queue& q = td.queue;
+      unsigned w_idx;
+      t = q.PushBackWithTag(std::move(t), pt->tag, w_idx);
+      if (t.f) {
+        // The queue rejected the work. Account for the missing capacity for work
+        // on the synchronization barrier. The semantics for RunInParallel are that
+        // the function is called with up to n-way parallelism, and so the
+        // work itself will be performed in the current thread's call to fn()
+        // after finishing adding work to the pool.
+ b.Notify(1); + } else { + // The queue accepted the work, ensure that the thread is servicing the queue + pending_items.push_back({q_idx, w_idx}); + td.EnsureAwake(); + } + } - int NumThreads() const EIGEN_FINAL { - return num_threads_; + // Run the final copy ourselves, for the total of n degree-of-parallelism + fn(); + + // Notify the barrier for the work we completed, plus any work that we successfully + // revoke from the work queues + int notifications_needed = 1; + for (auto& item : pending_items) { + Queue& q = thread_data_[item.first].queue; + if (q.RevokeWithTag(pt->tag, item.second)) { + notifications_needed++; + } + } + b.Notify(notifications_needed); + + // Synchronize with any work items that are still running + b.Wait(); + pt->in_parallel = false; } +} - int CurrentThreadId() const EIGEN_FINAL { - const PerThread* pt = const_cast(this)->GetPerThread(); - if (pt->pool == this) { - return pt->thread_id; +void Cancel() override { + cancelled_ = true; + // If done_ is true, which means this object is being destructing. + // Therefore thread_data_[i].thread could be NULL. + if (!done_) { + done_ = true; + // Let each thread know it's been cancelled. + for (size_t i = 0; i < thread_data_.size(); i++) { + assert(thread_data_[i].thread != nullptr); + thread_data_[i].thread->OnCancel(); } - return -1; } - private: - // Create a single atomic that encodes start and limit information for - // each thread. - // We expect num_threads_ < 65536, so we can store them in a single - // std::atomic. - // Exposed publicly as static functions so that external callers can reuse - // this encode/decode logic for maintaining their own thread-safe copies of - // scheduling and steal domain(s). - static const int kMaxPartitionBits = 16; - static const int kMaxThreads = 1 << kMaxPartitionBits; + // Wake up the threads without work to let them exit on their own. + WakeAllWorkersForExit(); +} - inline unsigned EncodePartition(unsigned start, unsigned limit) { - return (start << kMaxPartitionBits) | limit; - } +int NumThreads() const EIGEN_FINAL { + return num_threads_; +} - inline void DecodePartition(unsigned val, unsigned* start, unsigned* limit) { - *limit = val & (kMaxThreads - 1); - val >>= kMaxPartitionBits; - *start = val; +int CurrentThreadId() const EIGEN_FINAL { + const PerThread* pt = const_cast(this)->GetPerThread(); + if (pt->pool == this) { + return pt->thread_id; } + return -1; +} + + private: + #ifdef NDEBUG void AssertBounds(int, int) { } @@ -620,13 +654,6 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { assert(end <= num_threads_); } #endif - inline void SetStealPartition(size_t i, unsigned val) { - thread_data_[i].steal_partition.store(val, std::memory_order_relaxed); - } - - inline unsigned GetStealPartition(int i) { - return thread_data_[i].steal_partition.load(std::memory_order_relaxed); - } void ComputeCoprimes(int N, Eigen::MaxSizeVector* coprimes) { for (int i = 1; i <= N; i++) { @@ -645,21 +672,110 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { } typedef typename Environment::EnvThread Thread; + struct ThreadData; struct PerThread { constexpr PerThread() : pool(nullptr) { } - ThreadPoolTempl* pool; // Parent pool, or null for normal threads. - uint64_t rand{0}; // Random generator state. - int thread_id{-1}; // Worker thread index in pool. + ThreadPoolTempl* pool; // Parent pool, or null for normal threads. + uint64_t rand{0}; // Random generator state. + int thread_id{-1}; // Worker thread index in pool. 
+ Tag tag{}; // Work item tag used to identify this thread. + bool in_parallel{false}; // Inside a parallel section (hence tag not unique if we re-use) + std::vector good_hints; // Vector used to pass hints of workers to use + std::vector alt_hints; // Vector used to pass hints of workers to use if not sufficient good hints }; struct ThreadData { - constexpr ThreadData() : thread(), steal_partition(0), queue() { + constexpr ThreadData() : thread(), queue() { } std::unique_ptr thread; - std::atomic steal_partition; Queue queue; + + // Each thread has a status, available read-only without locking, and protected + // by the mutex field below for updates. The status is used for three + // purposes: + // + // 1. To identify threads that are good candidates to push work to. + // We prefer to push work to threads that are actively spinning (no need + // for an OS wake-up, and no need for current work to finish). After that, we + // prefer to push work to threads that are blocked (no need to wait for the + // current work to finish). + // + // 2. To identify threads that are good candidates to steal work from. We + // prefer to steal work from threads that are active outside the worker loop. + // This avoids "snatching" new work away from a thread that has just been + // given it but not yet noticed. + // + // 3. When pushing work to a thread, we use the status read-only to identify + // when we need to wake the thread. This read-only check avoids the + // need for mutex / condvar operations in the case where the thread pool + // remains busy. + + enum class ThreadStatus : uint8_t { + Spinning, // Spinning in the work loop, and other cases (initialization) where + // the thread will soon be in the loop + Active, // Running user code, not waiting for work + Blocking, // In the process of blocking; may no longer notice work pushed to it + Blocked, // Blocked on cv + Waking, // Not yet back in the worker loop, but wake-up notification sent + }; + + ThreadStatus GetStatus() const { + return status; + } + + // State transitions, called from other threads + + void EnsureAwake() { + ThreadStatus seen = status; + if (seen == ThreadStatus::Blocking || + seen == ThreadStatus::Blocked) { + std::unique_lock lk(mutex); + // Blocking state exists only transiently during the SetBlock() method + // while holding the lock. We may observe it at the start of this + // function, but after acquiring the lock then the target thread + // will either be blocked or not. 
+ seen = status; + assert(seen != ThreadStatus::Blocking); + if (seen == ThreadStatus::Blocked) { + status = ThreadStatus::Waking; + cv.notify_one(); + } + } + } + + // State transitions, called only from the thread itself + + void SetActive() { + std::unique_lock lk(mutex); + status = ThreadStatus::Active; + } + + void SetSpinning() { + std::unique_lock lk(mutex); + status = ThreadStatus::Spinning; + } + + void SetBlocked(std::function should_block, + std::function post_block) { + std::unique_lock lk(mutex); + assert(status == ThreadStatus::Spinning); + status = ThreadStatus::Blocking; + if (should_block()) { + status = ThreadStatus::Blocked; + while (status == ThreadStatus::Blocked) { + cv.wait(lk); + } + post_block(); + } + status = ThreadStatus::Spinning; + } + + private: + std::atomic status{ThreadStatus::Spinning}; + OrtMutex mutex; + OrtCondVar cv; }; Environment& env_; @@ -667,182 +783,180 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface { const bool allow_spinning_; Eigen::MaxSizeVector thread_data_; Eigen::MaxSizeVector> all_coprimes_; - Eigen::MaxSizeVector waiters_; - unsigned global_steal_partition_; - std::atomic blocked_; - std::atomic spinning_; + std::atomic blocked_; // Count of blocked workers, used as a termination condition std::atomic done_; std::atomic cancelled_; - EventCount ec_; + + // Allow control over how many bits to use in each entry in good_worker_hints_. + // We reduce this below the full 64-bit word size for two reasons. First, it + // helps test coverage on machines without 64 vCPUS. Second, it lets us + // reduce contention by having different threads start work searching for hints + // at different locations in the bitmap. + + static const int bits_per_hint_word_ = 4; + int num_hint_words_; + std::unique_ptr[]> good_worker_hints_; + + // Wake any blocked workers so that they can cleanly exit WorkerLoop(). For an + // abrupt exit, cancelled_==true and threads will exit their worker loops. For + // a clean exit, each thread will observe (1) done_ set, indicating that the + // destructor has been called, (2) all threads blocked, and (3) no + // items in the work queues. + + void WakeAllWorkersForExit() { + for (auto &td: thread_data_) { + td.EnsureAwake(); + } + } // Main worker thread loop. void WorkerLoop(int thread_id) { PerThread* pt = GetPerThread(); + ThreadData& td = thread_data_[thread_id]; + Queue& q = td.queue; + bool should_exit = false; pt->pool = this; pt->rand = GlobalThreadIdHash(); pt->thread_id = thread_id; - Queue& q = thread_data_[thread_id].queue; - EventCount::Waiter* waiter = &waiters_[thread_id]; - // TODO(dvyukov,rmlarsen): The time spent in NonEmptyQueueIndex() is - // proportional to num_threads_ and we assume that new work is scheduled at - // a constant rate, so we set spin_count to 5000 / num_threads_. The - // constant was picked based on a fair dice roll, tune it. - const int spin_count = allow_spinning_ && num_threads_ > 0 ? 5000 / num_threads_ : 0; - if (num_threads_ == 1) { - // For num_threads_ == 1 there is no point in going through the expensive - // steal loop. Moreover, since NonEmptyQueueIndex() calls PopBack() on the - // victim queues it might reverse the order in which ops are executed - // compared to the order in which they are scheduled, which tends to be - // counter-productive for the types of I/O workloads the single thread - // pools tend to be used for. 
-    while (!cancelled_) {
+
+    assert(td.GetStatus() == ThreadData::ThreadStatus::Spinning);
+    SetGoodWorkerHint(thread_id, true /* Is good */);
+
+    const int log2_spin = 20;
+    const int spin_count = allow_spinning_ ? (1ull << log2_spin) : 0;
+    const int steal_count = spin_count / 100;
+
+    while (!cancelled_ && !should_exit) {
+      Task t = q.PopFront();
+      if (!t.f) {
+        // Spin waiting for work, periodically making a best-effort attempt
+        // to steal from other queues.
+        for (int i = 0; i < spin_count && !t.f && !cancelled_ && !done_; i++) {
+          t = ((i + 1) % steal_count == 0) ? TrySteal() : q.PopFront();
+        }
+        if (!t.f)
+          t = Steal(true /* true => check all queues */);
+      }
       if (!t.f) {
-        // Leave one thread spinning. This reduces latency.
-        if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
-          for (int i = 0; i < spin_count && !t.f; i++) {
-            if (!cancelled_.load(std::memory_order_relaxed)) {
-              t = GlobalSteal();
-            } else {
-              return;
-            }
-          }
-          spinning_ = false;
-        }
-        if (!t.f) {
-          if (!WaitForWork(waiter, &t)) {
-            return;
-          }
-        }
+        td.SetBlocked(
+            // Pre-block test
+            [&]() -> bool {
+              bool should_block = true;
+              // We already did a best-effort emptiness check when stealing; now
+              // do a full check prior to blocking.
+              int victim = NonEmptyQueueIndex();
+              if (victim != -1) {
+                should_block = false;
+                if (!cancelled_) {
+                  t = thread_data_[victim].queue.PopBack();
+                }
+              }
+              // Number of blocked threads is used as termination condition.
+              // If we are shutting down and all worker threads blocked without work,
+              // that means we are done.
+              if (should_block) {
+                blocked_++;
+                if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
+                  should_block = false;
+                  // Almost done, but need to re-check queues.
+                  // Consider that all queues are empty and all worker threads are preempted
+                  // right after incrementing blocked_ above. Now a free-standing thread
+                  // submits work and calls destructor (which sets done_). If we don't
+                  // re-check queues, we will exit leaving the work unexecuted.
+                  if (NonEmptyQueueIndex() != -1) {
+                    // Note: we must not pop from queues before we decrement blocked_,
+                    // otherwise the following scenario is possible. Consider that instead
+                    // of checking for emptiness we popped the only element from queues.
+                    // Now other worker threads can start exiting, which is bad if the
+                    // work item submits other work. So we just check emptiness here,
+                    // which ensures that all worker threads exit at the same time.
+                    blocked_--;
+                  } else {
+                    should_exit = true;
+                  }
+                }
+              }
+              return should_block;
+            },
+            // Post-block update (executed only if we blocked)
+            [&]() {
+              blocked_--;
+            });
       }
     }
   }
     if (t.f) {
+      td.SetActive();
       env_.ExecuteTask(t);
+      td.SetSpinning();
     }
   }
+
+    // Whichever thread(s) observe the termination conditions are responsible for waking
+    // any other threads that have remained blocked.
+    if (should_exit) {
+      WakeAllWorkersForExit();
+    }
   }
-  }

   // Steal tries to steal work from other worker threads in the range [start,
-  // limit) in best-effort manner.
+  // limit) in best-effort manner. We make two passes over the threads:
+  //
+  // - round 0 : we attempt to steal from threads that are running in
+  //   user code (ThreadStatus::Active). The intuition behind this is that
+  //   the thread is busy with other work, and that by preferring to
+  //   steal from busy victims we will avoid "snatching" work from a
+  //   thread which is just about to notice the work itself.
+  //
+  // - round 1 : we steal work from any thread, including those which claim
+  //   to be spinning. In these cases, even though the victim thread is
+  //   looking for work itself, it may have been pre-empted.
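Reviewer note: the victim traversal in Steal below starts from a random index and repeatedly adds a step drawn from all_coprimes_, relying on the fact that stepping by a value coprime with size visits every queue exactly once per round. A toy standalone check of that property:

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Toy check: stepping by any `inc` coprime with `size` visits each index once.
int main() {
  const unsigned size = 6;
  const unsigned inc = 5;  // std::gcd(5, 6) == 1, as the constructor guarantees
  assert(std::gcd(inc, size) == 1u);

  std::vector<bool> visited(size, false);
  unsigned victim = 3;  // arbitrary random start, as in Steal()
  for (unsigned i = 0; i < size; i++) {
    assert(!visited[victim]);  // no repeats before full coverage
    visited[victim] = true;
    victim += inc;
    if (victim >= size) victim -= size;
  }
  return 0;  // all `size` queues were visited exactly once
}
```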
+ + Task Steal(bool check_all) { PerThread* pt = GetPerThread(); - const unsigned size = static_cast(limit - start); + unsigned size = static_cast(num_threads_); unsigned r = Rand(&pt->rand); - unsigned victim = r % size; unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()]; - for (unsigned i = 0; i < size; i++) { - assert(start + victim < limit); - Task t = thread_data_[start + victim].queue.PopBack(); - if (t.f) { - return t; - } - victim += inc; - if (victim >= size) { - victim -= size; + for (int round = 0; round < 2; round++) { + unsigned victim = r % size; + for (unsigned i = 0; i < size; i++) { + assert(victim < size); + if (round == 1 || + thread_data_[victim].GetStatus() == ThreadData::ThreadStatus::Active) { + Task t = thread_data_[victim].queue.PopBack(); + if (t.f) { + return t; + } + } + if (!check_all) { + return Task(); + } + victim += inc; + if (victim >= size) { + victim -= size; + } } } + return Task(); } - // Steals work within threads belonging to the partition. - Task LocalSteal() { - PerThread* pt = GetPerThread(); - unsigned partition = GetStealPartition(pt->thread_id); - // If thread steal partition is the same as global partition, there is no - // need to go through the steal loop twice. - if (global_steal_partition_ == partition) - return Task(); - unsigned start; - unsigned limit; - DecodePartition(partition, &start, &limit); - AssertBounds(start, limit); - - return Steal(start, limit); - } - - // Steals work from any other thread in the pool. - Task GlobalSteal() { - return Steal(0, num_threads_); - } - - // WaitForWork blocks until new work is available (returns true), or if it is - // time to exit (returns false). Can optionally return a task to execute in t - // (in such case t.f != nullptr on return). - bool WaitForWork(EventCount::Waiter* waiter, Task* t) { - assert(!t->f); - // We already did best-effort emptiness check in Steal, so prepare for - // blocking. - ec_.Prewait(); - // Now do a reliable emptiness check. - int victim = NonEmptyQueueIndex(); - if (victim != -1) { - ec_.CancelWait(); - if (cancelled_) { - return false; - } - *t = thread_data_[victim].queue.PopBack(); - return true; - } - // Number of blocked threads is used as termination condition. - // If we are shutting down and all worker threads blocked without work, - // that's we are done. - blocked_++; - // TODO is blocked_ required to be unsigned? - if (done_ && blocked_ == static_cast(num_threads_)) { - ec_.CancelWait(); - // Almost done, but need to re-check queues. - // Consider that all queues are empty and all worker threads are preempted - // right after incrementing blocked_ above. Now a free-standing thread - // submits work and calls destructor (which sets done_). If we don't - // re-check queues, we will exit leaving the work unexecuted. - if (NonEmptyQueueIndex() != -1) { - // Note: we must not pop from queues before we decrement blocked_, - // otherwise the following scenario is possible. Consider that instead - // of checking for emptiness we popped the only element from queues. - // Now other worker threads can start exiting, which is bad if the - // work item submits other work. So we just check emptiness here, - // which ensures that all worker threads exit at the same time. - blocked_--; - return true; - } - // Reached stable termination state. 
- ec_.Notify(true); - return false; - } - ec_.CommitWait(waiter); - blocked_--; - return true; + Task TrySteal() { + return Steal(false); } int NonEmptyQueueIndex() { PerThread* pt = GetPerThread(); - // We intentionally design NonEmptyQueueIndex to steal work from - // anywhere in the queue so threads don't block in WaitForWork() forever - // when all threads in their partition go to sleep. Steal is still local. const unsigned size = static_cast(thread_data_.size()); unsigned r = Rand(&pt->rand); unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()]; diff --git a/include/onnxruntime/core/platform/threadpool.h b/include/onnxruntime/core/platform/threadpool.h index 4a9d7620434e6..93ec39ecf5d75 100644 --- a/include/onnxruntime/core/platform/threadpool.h +++ b/include/onnxruntime/core/platform/threadpool.h @@ -44,7 +44,12 @@ struct TensorOpCost { template class ThreadPoolTempl; + namespace concurrency { + +class ExtendedThreadPoolInterface; +class BatchHandle; + class ThreadPool { public: // Scheduling strategies for ParallelFor. The strategy governs how the given @@ -123,13 +128,11 @@ class ThreadPool { // // REQUIRES: num_threads > 0 // The allocator parameter is only used for creating a Eigen::ThreadPoolDevice to be used with Eigen Tensor classes. - ThreadPool(Env* env, const ThreadOptions& thread_options, const NAME_CHAR_TYPE* name, int num_threads, + ThreadPool(Env* env, + const ThreadOptions& thread_options, + const NAME_CHAR_TYPE* name, + int num_threads, bool low_latency_hint); - // Constructs a pool that wraps around the thread::ThreadPoolInterface - // instance provided by the caller. Caller retains ownership of - // `user_threadpool` and must ensure its lifetime is longer than the - // ThreadPool instance. - ThreadPool(Eigen::ThreadPoolInterface* user_threadpool); // Waits until all scheduled work has finished and then destroy the // set of threads. @@ -140,7 +143,8 @@ class ThreadPool { // Returns the number of shards used by ParallelForFixedBlockSizeScheduling // with these parameters. - int NumShardsUsedByFixedBlockSizeScheduling(std::ptrdiff_t total, std::ptrdiff_t block_size); + int NumShardsUsedByFixedBlockSizeScheduling(std::ptrdiff_t total, + std::ptrdiff_t block_size) const; // ParallelFor shards the "total" units of work assuming each unit of work // having roughly "cost_per_unit" cost, in cycles. Each unit of work is @@ -224,11 +228,6 @@ class ThreadPool { // thread in the pool. Returns -1 otherwise. int CurrentThreadId() const; - // If ThreadPool implementation is compatible with Eigen::ThreadPoolInterface, - // returns a non-null pointer. The caller does not own the object the returned - // pointer points to, and should not attempt to delete. - Eigen::ThreadPoolInterface* AsEigenThreadPool() const; - // Directly schedule the 'total' tasks to the underlying threadpool, without // cutting them by halves void SimpleParallelFor(std::ptrdiff_t total, const std::function& fn); @@ -335,6 +334,13 @@ class ThreadPool { ORT_DISALLOW_COPY_AND_ASSIGNMENT(ThreadPool); private: + // Run fn with up to n degree-of-parallelism enlisting the thread pool for + // help. The degree-of-parallelism includes the caller, and so if n==1 + // then the function will run directly in the caller. The fork-join + // synchronization is handled in the thread pool, and so any state captured + // by fn() is safe from concurrent access once RunWithHelp returns. + void RunInParallel(std::function fn, int n); + // Divides the work represented by the range [0, total) into k shards. 
  // Calls fn(i*block_size, (i+1)*block_size) from the ith shard (0 <= i < k).
  // Each shard may be executed on a different thread in parallel, depending on
@@ -344,13 +350,21 @@ class ThreadPool {
  // Requires 0 < block_size <= total.
  void ParallelForFixedBlockSizeScheduling(std::ptrdiff_t total, std::ptrdiff_t block_size,
                                           const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);
+
+
+  // Return whether or not the calling thread should run a loop of
+  // num_iterations divided in chunks of block_size in parallel. If not,
+  // the caller should run the loop sequentially.
+  bool ShouldParallelizeLoop(const std::ptrdiff_t num_iterations,
+                             const std::ptrdiff_t block_size = 1) const;
+
   ThreadOptions thread_options_;
   // underlying_threadpool_ is the user_threadpool if user_threadpool is
   // provided in the constructor. Otherwise it is the eigen_threadpool_.
-  Eigen::ThreadPoolInterface* underlying_threadpool_;
+  ExtendedThreadPoolInterface* underlying_threadpool_;
   // eigen_threadpool_ is instantiated and owned by thread::ThreadPool if
   // user_threadpool is not in the constructor.
-  std::unique_ptr<ThreadPoolTempl<Env> > eigen_threadpool_;
+  std::unique_ptr<ThreadPoolTempl<Env> > extended_eigen_threadpool_;
 };

 }  // namespace concurrency
diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc
index f9ddc902a5222..a89b0ce3be013 100644
--- a/onnxruntime/core/common/threadpool.cc
+++ b/onnxruntime/core/common/threadpool.cc
@@ -1,4 +1,3 @@
-
 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,101 +22,164 @@ limitations under the License.
 #include "core/platform/ort_mutex.h"
 namespace onnxruntime {
-namespace {
-class BlockingCounter {
- public:
-  BlockingCounter(int initial_count) : state_(initial_count << 1), notified_(false) {
-    ORT_ENFORCE(initial_count >= 0);
-#ifndef NDEBUG
-    ORT_ENFORCE(((initial_count << 1) >> 1) == initial_count);
-#endif
-  }
-  ~BlockingCounter() = default;
+namespace concurrency {

-  inline void DecrementCount() {
-    unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
-    if (v != 1) {
-#ifndef NDEBUG
-      ORT_ENFORCE(((v + 2) & ~1) != 0);
+// A sharded loop counter distributes loop iterations between a set of worker threads. The iteration space of
+// the loop is divided (perhaps unevenly) between the shards. Each thread has a home shard (perhaps not unique
+// to it), and it claims iterations via atomic operations on its home shard. It then proceeds through the other
+// shards until all of the shards' iterations are complete. This approach serves two purposes. First, compared
+// with atomic operations on a single counter, it reduces contention in the case of loops with
+// large numbers of short-running iterations. Second, by having a thread work on its home shard initially, it
+// promotes affinity between the work that a thread performs in one loop and the work that it performs in the next.
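Reviewer note: to make the shard arithmetic in the LoopCounter constructor below concrete, the standalone fragment here reproduces the boundary computation for a hypothetical loop of 100 iterations with block_size 3 and the patch's 8 shards. Boundaries land on multiples of the block size, and the final shard absorbs the remainder, ending exactly at 100:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative reimplementation of the shard-boundary computation used by
// LoopCounter's constructor below.
int main() {
  const uint64_t num_iterations = 100, block_size = 3;
  const int num_shards = 8;
  double per_shard = static_cast<double>(num_iterations) / num_shards;

  uint64_t split = 0;
  for (int shard = 0; shard < num_shards; shard++) {
    uint64_t next = static_cast<uint64_t>((shard + 1) * per_shard) / block_size * block_size;
    if (shard == num_shards - 1) next = num_iterations;  // final shard is exact
    std::printf("shard %d: [%llu, %llu)\n", shard,
                (unsigned long long)split, (unsigned long long)next);
    split = next;
  }
  // Prints [0,12) [12,24) [24,36) [36,48) [48,60) [60,75) [75,87) [87,100):
  // uneven, but each boundary except the last is a multiple of block_size.
}
```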
+
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4324) /* Padding added to LoopCounterShard, LoopCounter for alignment */
+#endif
-      return;  // either count has not dropped to 0, or waiter is not waiting
-    }
-    std::lock_guard<OrtMutex> l(mu_);
-    notified_ = true;
-    cond_var_.notify_all();
-  }
-  inline void Wait() {
-    unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
-    if ((v >> 1) == 0)
-      return;
-    std::unique_lock<OrtMutex> l(mu_);
-    while (!notified_) {
-      cond_var_.wait(l);
+static constexpr int CACHE_LINE_BYTES = 64;
+static constexpr int NUM_SHARDS = 8;
+
+struct alignas(CACHE_LINE_BYTES) LoopCounterShard {
+  ::std::atomic<uint64_t> _next;
+  uint64_t _end;
+};
+
+class alignas(CACHE_LINE_BYTES) LoopCounter {
+public:
+  LoopCounter(const ThreadPool& tp,
+              uint64_t num_iterations,
+              uint64_t block_size = 1) : _tp(tp),
+                                         _block_size(block_size) {
+    assert(sizeof(LoopCounterShard) == 64);
+    assert(block_size != 0);
+
+    // Divide the iteration space into NUM_SHARDS pieces. If the iteration space does not
+    // divide evenly into shards of multiples of block_size then the final shard is left uneven.
+    double iterations_per_shard = static_cast<double>(num_iterations) / NUM_SHARDS;
+    uint64_t split = 0;
+    for (uint64_t shard = 0; shard < NUM_SHARDS; shard++) {
+      _shards[shard]._next = split;
+      split = (static_cast<uint64_t>((shard + 1) * iterations_per_shard) / block_size) * block_size;
+      _shards[shard]._end = split;
+    }
+
+    // Ensure that the final shard finishes precisely at the end of the iteration space
+    _shards[NUM_SHARDS - 1]._end = num_iterations;
+  }
+
+  int GetHomeShard() const {
+    // Allocate each thread to a home shard, from which it starts claiming iterations. The allocation
+    // does not need to be unique, but we aim for a good distribution, particularly in the case where
+    // most/all of the thread pool's threads are active in the loop. Threads outside the pool may
+    // also be claiming work, with CurrentThreadId -1.
+    int num_threads = _tp.NumThreads();
+    int my_thread_idx = (_tp.CurrentThreadId() + 1) % num_threads;
+    assert(my_thread_idx >= 0 && my_thread_idx < num_threads);
+
+    int home_shard;
+    if (num_threads >= NUM_SHARDS) {
+      // More threads than shards => allocate them home shards round-robin, aiming to spread the load across
+      // the shards
+      home_shard = my_thread_idx % NUM_SHARDS;
+    } else {
+      // Fewer threads than shards => spread the threads evenly across the shards, so each will work
+      // on a run of successive shards before contention
+      home_shard = (my_thread_idx * NUM_SHARDS) / num_threads;
    }
+    assert(home_shard >= 0 && home_shard < NUM_SHARDS);
+    return home_shard;
  }
-  // Wait for the specified time, return false iff the count has not dropped to
-  // zero before the timeout expired.
-  inline bool WaitFor(std::chrono::milliseconds ms) {
-    unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
-    if ((v >> 1) == 0)
-      return true;
-    std::unique_lock<OrtMutex> l(mu_);
-    while (!notified_) {
-      const std::cv_status status = cond_var_.wait_for(l, ms);
-      if (status == std::cv_status::timeout) {
-        return false;
+
+  // Attempt to claim iterations from the sharded counter. The function either
+  // returns true, along with a non-empty block of up to block_size iterations, or it returns false
+  // if all of the iterations have been claimed.
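Reviewer note: a worked example of the GetHomeShard() mapping above, assuming a pool of 3 threads and the patch's 8 shards. An external caller (CurrentThreadId of -1) maps to index 0 and therefore shares a home shard with the pool's last thread:

```cpp
#include <cstdio>

// Standalone rendering of GetHomeShard()'s mapping for num_threads < NUM_SHARDS.
int main() {
  const int kNumShards = 8;
  const int num_threads = 3;  // assumed pool size for illustration
  for (int tid = -1; tid < num_threads; tid++) {  // -1 models an external thread
    int idx = (tid + 1) % num_threads;
    int home = (num_threads >= kNumShards) ? idx % kNumShards
                                           : (idx * kNumShards) / num_threads;
    std::printf("thread %2d -> home shard %d\n", tid, home);
  }
  // Prints: -1 -> 0, 0 -> 2, 1 -> 5, 2 -> 0. The threads start on well-spaced
  // shards, though the mapping is not unique.
}
```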
+ bool ClaimIterations(int my_home_shard, + int& my_shard, + uint64_t& my_start, + uint64_t& my_end) { + do { + if (_shards[my_shard]._next < _shards[my_shard]._end) { + // Appears to be work in the current shard, try to claim with atomic fetch-and-add + uint64_t temp_start = _shards[my_shard]._next.fetch_add(_block_size); + if (temp_start < _shards[my_shard]._end) { + my_start = temp_start; + my_end = std::min(_shards[my_shard]._end, temp_start + _block_size); + return true; + } } - } - return true; + // Work in the current shard is exhausted, move to the next shard, until + // we are back at the home shard. + my_shard = (my_shard + 1) % NUM_SHARDS; + } while (my_shard != my_home_shard); + return false; } - private: - OrtMutex mu_; - OrtCondVar cond_var_; - std::atomic state_; // low bit is waiter flag - bool notified_; +private: + alignas(CACHE_LINE_BYTES) LoopCounterShard _shards[NUM_SHARDS]; + const ThreadPool& _tp; + const uint64_t _block_size; }; -} // namespace -namespace concurrency { + +#ifdef _MSC_VER +#pragma warning(pop) /* Padding added in LoopCounterShard, LoopCounter */ +#endif ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, const NAME_CHAR_TYPE* name, int num_threads, bool low_latency_hint) : thread_options_(thread_options) { ORT_ENFORCE(num_threads >= 1); - eigen_threadpool_ = + extended_eigen_threadpool_ = onnxruntime::make_unique>(name, num_threads, low_latency_hint, *env, thread_options_); - underlying_threadpool_ = eigen_threadpool_.get(); -} - -ThreadPool::ThreadPool(Eigen::ThreadPoolInterface* user_threadpool) - : thread_options_(ThreadOptions()) { - underlying_threadpool_ = user_threadpool; + underlying_threadpool_ = extended_eigen_threadpool_.get(); } ThreadPool::~ThreadPool() = default; -void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, const std::function& fn) { +// Base case for parallel loops, running iterations 0..total, divided into blocks +// of block_size iterations, and calling into a function that takes a start..end +// range of indices to run. +void ThreadPool::ParallelForFixedBlockSizeScheduling(const std::ptrdiff_t total, + const std::ptrdiff_t block_size, + const std::function& fn) { if (total <= 0) return; - if (total == 1) { - fn(0); + if (total <= block_size) { + fn(0, total); return; } - Barrier barrier(static_cast(total)); - std::function handle_iteration = [&barrier, &fn](std::ptrdiff_t iteration) { - fn(iteration); - barrier.Notify(); + // Split the work across threads in the pool. Each work item will run a loop claiming iterations, + // hence we need at most one for each thread, even if the numberof blocks of iterations is larger. + int num_threads = NumThreads(); + int num_work_items = static_cast(std::min(static_cast(num_threads), total)); + assert(num_work_items > 0); + + LoopCounter lc(*this, total, block_size); + std::function run_work = [&]() { + int my_home_shard = lc.GetHomeShard(); + int my_shard = my_home_shard; + uint64_t my_iter_start, my_iter_end; + while (lc.ClaimIterations(my_home_shard, my_shard, my_iter_start, my_iter_end)) { + fn(static_cast(my_iter_start), + static_cast(my_iter_end)); + } }; - for (std::ptrdiff_t id = 0; id < total; ++id) { - Schedule([=, &handle_iteration]() { handle_iteration(id); }); - } + // Run the work in the thread pool (and in the current thread). Synchronization with helping + // threads is handled within RunInParallel, hence we can deallocate lc and other state captured by + // run_work. 
-void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn) {
+// Base case for parallel loops, running iterations 0..total, divided into blocks
+// of block_size iterations, and calling into a function that takes a start..end
+// range of indices to run.
+void ThreadPool::ParallelForFixedBlockSizeScheduling(const std::ptrdiff_t total,
+                                                     const std::ptrdiff_t block_size,
+                                                     const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
   if (total <= 0)
     return;

-  if (total == 1) {
-    fn(0);
+  if (total <= block_size) {
+    fn(0, total);
     return;
   }

-  Barrier barrier(static_cast<unsigned int>(total));
-  std::function<void(std::ptrdiff_t)> handle_iteration = [&barrier, &fn](std::ptrdiff_t iteration) {
-    fn(iteration);
-    barrier.Notify();
+  // Split the work across threads in the pool.  Each work item will run a loop claiming iterations,
+  // hence we need at most one for each thread, even if the number of blocks of iterations is larger.
+  int num_threads = NumThreads();
+  int num_work_items = static_cast<int>(std::min(static_cast<std::ptrdiff_t>(num_threads), total));
+  assert(num_work_items > 0);
+
+  LoopCounter lc(*this, total, block_size);
+  std::function<void()> run_work = [&]() {
+    int my_home_shard = lc.GetHomeShard();
+    int my_shard = my_home_shard;
+    uint64_t my_iter_start, my_iter_end;
+    while (lc.ClaimIterations(my_home_shard, my_shard, my_iter_start, my_iter_end)) {
+      fn(static_cast<std::ptrdiff_t>(my_iter_start),
+         static_cast<std::ptrdiff_t>(my_iter_end));
+    }
   };
-  for (std::ptrdiff_t id = 0; id < total; ++id) {
-    Schedule([=, &handle_iteration]() { handle_iteration(id); });
-  }
+
+  // Run the work in the thread pool (and in the current thread).  Synchronization with helping
+  // threads is handled within RunInParallel, hence we can deallocate lc and other state captured by
+  // run_work once it returns.
+  RunInParallel(run_work, num_work_items);
+}

-  barrier.Wait();
+void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn) {
+  ParallelForFixedBlockSizeScheduling(total, 1, [&](std::ptrdiff_t first, std::ptrdiff_t last) {
+    for (std::ptrdiff_t idx = first; idx < last; idx++) {
+      fn(idx);
+    }
+  });
 }

 void ThreadPool::Schedule(std::function<void()> fn) {
@@ -125,12 +187,37 @@ void ThreadPool::Schedule(std::function<void()> fn) {
   underlying_threadpool_->Schedule(std::move(fn));
 }

-int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(const std::ptrdiff_t total, const std::ptrdiff_t block_size) {
-  if (block_size <= 0 || total <= 1 || total <= block_size || NumThreads() == 1) {
+void ThreadPool::RunInParallel(std::function<void()> fn, int n) {
+  ORT_ENFORCE(fn != nullptr);
+  underlying_threadpool_->RunInParallel(std::move(fn), n);
+}
+
+bool ThreadPool::ShouldParallelizeLoop(const std::ptrdiff_t num_iterations,
+                                       const std::ptrdiff_t block_size) const {
+  // Do not parallelize trivial loops, with only a single block of work
+  if (block_size <= 0 || num_iterations <= block_size) {
+    return false;
+  }
+
+  // Do not parallelize loops with only a single thread available.  If the
+  // caller is outside the current pool (ID == -1) then we parallelize
+  // via the pool's thread(s).  If the caller is inside the current pool
+  // (ID != -1) then we require at least one additional thread in the pool.
+  if (CurrentThreadId() != -1 && NumThreads() == 1) {
+    return false;
+  }
+
+  return true;
+}
+
+int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(const std::ptrdiff_t total,
+                                                        const std::ptrdiff_t block_size) const {
+  if (!ShouldParallelizeLoop(total, block_size)) {
     return 1;
+  } else {
+    // TODO: check overflow?
+    return static_cast<int>((total + block_size - 1) / block_size);
   }
-  // TODO:check overflow?
-  return static_cast<int>((total + block_size - 1) / block_size);
 }
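The ceiling division in NumShardsUsedByFixedBlockSizeScheduling is worth a worked example. Below is a small sketch (NumBlocks is a hypothetical helper, not part of the ORT API) showing the block counts, including the single-block case that ShouldParallelizeLoop filters out. Note the diff's own TODO: total + block_size - 1 can overflow std::ptrdiff_t for extreme totals.

#include <cstddef>
#include <iostream>

// Ceiling division used for the shard count: the last block may be partial.
std::ptrdiff_t NumBlocks(std::ptrdiff_t total, std::ptrdiff_t block_size) {
  return (total + block_size - 1) / block_size;
}

int main() {
  std::cout << NumBlocks(100, 16) << "\n";  // 7 blocks: 6 full plus one of 4 iterations
  std::cout << NumBlocks(96, 16) << "\n";   // 6 blocks: divides exactly
  // With total <= block_size there is a single block, so ShouldParallelizeLoop
  // returns false and the loop runs inline on the calling thread.
  std::cout << NumBlocks(10, 16) << "\n";   // 1
  return 0;
}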
 void ThreadPool::ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params,
@@ -151,41 +238,6 @@ void ThreadPool::ParallelFor(std::ptrdiff_t total, const SchedulingParams& sched
   }
 }

-// This functionality is similar to parallelFor, except that reasoning about
-// the number of shards used is significantly easier.
-void ThreadPool::ParallelForFixedBlockSizeScheduling(const std::ptrdiff_t total, const std::ptrdiff_t block_size,
-                                                     const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
-  const int num_shards_used = NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
-  if (num_shards_used == 1) {
-    fn(0, total);
-    return;
-  }
-
-  // Adapted from Eigen's parallelFor implementation.
-  BlockingCounter counter(num_shards_used);
-  std::function<void(std::ptrdiff_t, std::ptrdiff_t)> handle_range = [=, &handle_range, &counter, &fn](std::ptrdiff_t first,
-                                                                                                       std::ptrdiff_t last) {
-    while (last - first > block_size) {
-      // Find something near the midpoint which is a multiple of block size.
-      const std::ptrdiff_t mid = first + ((last - first) / 2 + block_size - 1) / block_size * block_size;
-      Schedule([=, &handle_range]() { handle_range(mid, last); });
-      last = mid;
-    }
-    // Single block or less, execute directly.
-    fn(first, last);
-    counter.DecrementCount();  // The shard is done.
-  };
-
-  // Execute the root in the thread pool to avoid running work on more than
-  // numThreads() threads.
-  Schedule([=, &handle_range]() { handle_range(0, total); });
-  counter.Wait();
-}
-
-struct ParallelForBlock {
-  ptrdiff_t size;   // block size
-  ptrdiff_t count;  // number of blocks
-};

 using CostModel = Eigen::TensorCostModel<Eigen::ThreadPoolDevice>;

 // Calculates block size based on (1) the iteration cost and (2) parallel
@@ -193,8 +245,8 @@ using CostModel = Eigen::TensorCostModel<Eigen::ThreadPoolDevice>;
 // overheads; not too large to mitigate tail effect and potential load
 // imbalance and we also want number of blocks to be evenly dividable across
 // threads.
-static ParallelForBlock CalculateParallelForBlock(const ptrdiff_t n, const Eigen::TensorOpCost& cost,
-                                                  std::function<ptrdiff_t(ptrdiff_t)> block_align, int num_threads) {
+static ptrdiff_t CalculateParallelForBlock(const ptrdiff_t n, const Eigen::TensorOpCost& cost,
+                                           std::function<ptrdiff_t(ptrdiff_t)> block_align, int num_threads) {
   const double block_size_f = 1.0 / CostModel::taskSize(1, cost);
   const ptrdiff_t max_oversharding_factor = 4;
   ptrdiff_t block_size = Eigen::numext::mini(
@@ -245,7 +297,7 @@ static ptrdiff_t CalculateParallelForBlock(const ptrdiff_t n, const Eigen
     }
   }

-  return {block_size, block_count};
+  return block_size;
 }

 void ThreadPool::ParallelFor(std::ptrdiff_t n, const TensorOpCost& c,
@@ -253,35 +305,16 @@
   ORT_ENFORCE(n >= 0);
   Eigen::TensorOpCost cost{c.bytes_loaded, c.bytes_stored, c.compute_cycles};
   // Compute small problems directly in the caller thread.
-  if (n <= 1 || NumThreads() == 1 ||
+  if ((!ShouldParallelizeLoop(n)) ||
       Eigen::TensorCostModel<Eigen::ThreadPoolDevice>::numThreads(static_cast<double>(n), cost, static_cast<int>(NumThreads())) == 1) {
     f(0, n);
     return;
   }

-  // Compute block size and total count of blocks.
-  ParallelForBlock block = CalculateParallelForBlock(n, cost, nullptr, NumThreads());
-
-  // Recursively divide size into halves until we reach block_size.
-  // Division code rounds mid to block_size, so we are guaranteed to get
-  // block_count leaves that do actual computations.
-  Barrier barrier(static_cast<unsigned int>(block.count));
-  std::function<void(ptrdiff_t, ptrdiff_t)> handleRange;
-  handleRange = [=, &handleRange, &barrier, &f](ptrdiff_t firstIdx, ptrdiff_t lastIdx) {
-    while (lastIdx - firstIdx > block.size) {
-      // Split into halves and schedule the second half on a different thread.
-      const ptrdiff_t midIdx = firstIdx + Eigen::divup((lastIdx - firstIdx) / 2, block.size) * block.size;
-      underlying_threadpool_->Schedule([=, &handleRange]() { handleRange(midIdx, lastIdx); });
-      lastIdx = midIdx;
-    }
-    // Single block or less, execute directly.
-    f(firstIdx, lastIdx);
-    barrier.Notify();
-  };
-
-  underlying_threadpool_->Schedule([=, &handleRange]() { handleRange(0, n); });
-  barrier.Wait();
+  ptrdiff_t block = CalculateParallelForBlock(n, cost, nullptr, NumThreads());
+  ParallelForFixedBlockSizeScheduling(n, block, f);
 }

+
 void ThreadPool::ParallelFor(std::ptrdiff_t total, double cost_per_unit,
                              const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn) {
   ParallelFor(total, TensorOpCost{0, 0, static_cast<double>(cost_per_unit)}, fn);
@@ -300,13 +333,11 @@ int ThreadPool::NumThreads() const {
   return underlying_threadpool_->NumThreads();
 }
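After this change, CalculateParallelForBlock returns only a block size, which feeds straight into ParallelForFixedBlockSizeScheduling. The sketch below illustrates the shape of that trade-off with a simplified stand-in: PickBlockSize and its fixed cycle target are assumptions for illustration; Eigen's CostModel calibrates the per-task cost rather than using a constant.

#include <algorithm>
#include <cstddef>
#include <iostream>

// Simplified stand-in for a cost-based block-size choice (illustrative only).
// cost_per_iter: estimated cycles for one iteration of the loop body.
std::ptrdiff_t PickBlockSize(std::ptrdiff_t n, double cost_per_iter, int num_threads) {
  const double target_cycles_per_block = 100000.0;   // assumed tuning constant
  const std::ptrdiff_t max_oversharding_factor = 4;  // keep >= 4 blocks per thread
  // Blocks large enough to amortize per-block scheduling overhead...
  std::ptrdiff_t block = std::max<std::ptrdiff_t>(
      1, static_cast<std::ptrdiff_t>(target_cycles_per_block / cost_per_iter));
  // ...but small enough that every thread still gets several blocks,
  // limiting tail effects and load imbalance.
  std::ptrdiff_t max_block = std::max<std::ptrdiff_t>(1, n / (max_oversharding_factor * num_threads));
  return std::min(block, max_block);
}

int main() {
  // Cheap iterations => larger blocks; expensive iterations => smaller blocks.
  std::cout << PickBlockSize(1 << 20, 10.0, 8) << "\n";     // 10000
  std::cout << PickBlockSize(1 << 20, 10000.0, 8) << "\n";  // 10
  return 0;
}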
+// Return the ID of the current thread within this pool.  Returns -1 for a thread outside the
+// current pool.
 int ThreadPool::CurrentThreadId() const {
   return underlying_threadpool_->CurrentThreadId();
 }

-Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const {
-  ORT_ENFORCE(underlying_threadpool_ != nullptr);
-  return underlying_threadpool_;
-}
 }  // namespace concurrency
 }  // namespace onnxruntime

diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc
index be2184c92b8a8..bed62615ff72f 100644
--- a/onnxruntime/test/perftest/performance_runner.cc
+++ b/onnxruntime/test/perftest/performance_runner.cc
@@ -81,13 +81,13 @@ void PerformanceResult::DumpToFile(const std::basic_string<ORTCHAR_T>& path, boo
   std::sort(sorted_time.begin(), sorted_time.end());

   auto output_stats = [&](std::ostream& ostream) {
-    ostream << "Min Latency is " << sorted_time[0] << "sec\n";
-    ostream << "Max Latency is " << sorted_time[total - 1] << "sec\n";
-    ostream << "P50 Latency is " << sorted_time[n50] << "sec\n";
-    ostream << "P90 Latency is " << sorted_time[n90] << "sec\n";
-    ostream << "P95 Latency is " << sorted_time[n95] << "sec\n";
-    ostream << "P99 Latency is " << sorted_time[n99] << "sec\n";
-    ostream << "P999 Latency is " << sorted_time[n999] << "sec" << std::endl;
+    ostream << "Min Latency: " << sorted_time[0] << " s\n";
+    ostream << "Max Latency: " << sorted_time[total - 1] << " s\n";
+    ostream << "P50 Latency: " << sorted_time[n50] << " s\n";
+    ostream << "P90 Latency: " << sorted_time[n90] << " s\n";
+    ostream << "P95 Latency: " << sorted_time[n95] << " s\n";
+    ostream << "P99 Latency: " << sorted_time[n99] << " s\n";
+    ostream << "P999 Latency: " << sorted_time[n999] << " s" << std::endl;
   };

   if (have_file) {
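The percentile indices n50..n999 used by output_stats are computed elsewhere in the runner; one common convention on a sorted sample is shown below (an assumption for illustration, not necessarily the exact formula perf_test uses).

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> sorted_time = {0.010, 0.011, 0.012, 0.013, 0.014,
                                     0.015, 0.016, 0.017, 0.018, 0.050};
  std::sort(sorted_time.begin(), sorted_time.end());
  size_t total = sorted_time.size();
  // Index = sample count * percentile, clamped to the last element.
  size_t n50 = static_cast<size_t>(total * 0.5);
  size_t n99 = std::min(total - 1, static_cast<size_t>(total * 0.99));
  std::cout << "P50 Latency: " << sorted_time[n50] << " s\n"
            << "P99 Latency: " << sorted_time[n99] << " s\n";
  return 0;
}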
@@ -132,12 +132,16 @@ Status PerformanceRunner::Run() {
   //  if (!performance_test_config_.run_config.profile_file.empty())
   //    session_object->EndProfiling();
   std::chrono::duration<double> inference_duration = performance_result_.end - performance_result_.start;
-  std::cout << "Session creation time cost:" << session_create_duration.count() << " s" << std::endl
-            << "Total inference time cost:" << performance_result_.total_time_cost << " s" << std::endl  // sum of time taken by each request
-            << "Total inference requests:" << performance_result_.time_costs.size() << std::endl
-            << "Average inference time cost:" << performance_result_.total_time_cost / performance_result_.time_costs.size() * 1000 << " ms" << std::endl
+  std::cout << "Session creation time cost: " << session_create_duration.count() << " s\n"
+            << "Total inference time cost: " << performance_result_.total_time_cost << " s\n"  // sum of time taken by each request
+            << "Total inference requests: " << performance_result_.time_costs.size() << "\n"
+            << "Average inference time cost: " << performance_result_.total_time_cost / performance_result_.time_costs.size() * 1000 << " ms\n"
             // Time between start and end of run.  Less than Total time cost when running requests in parallel.
-            << "Total inference run time:" << inference_duration.count() << " s" << std::endl;
+            << "Total inference run time: " << inference_duration.count() << " s\n"
+            << "Avg CPU usage: " << performance_result_.average_CPU_usage << " %\n"
+            << "Peak working set size: " << performance_result_.peak_workingset_size << " bytes"
+            << std::endl;
+
   return Status::OK();
 }

diff --git a/onnxruntime/test/platform/threadpool_test.cc b/onnxruntime/test/platform/threadpool_test.cc
index 0bfcb2ae76ed0..d7ab072be6da3 100644
--- a/onnxruntime/test/platform/threadpool_test.cc
+++ b/onnxruntime/test/platform/threadpool_test.cc
@@ -67,6 +67,43 @@ void TestBatchParallelFor(const std::string& name, int num_threads, int num_task
   ValidateTestData(*test_data);
 }

+void TestMultipleParallelFor(const std::string& name, int num_threads, int num_concurrent, int num_tasks) {
+  // Test running multiple concurrent loops over the same thread pool.  This aims to provoke a
+  // more diverse mix of interleavings than with a single loop running at a time.
+  for (int rep = 0; rep < 5; rep++) {
+    CreateThreadPoolAndTest(name, num_threads, [&](ThreadPool* tp) {
+      std::vector<std::unique_ptr<TestData>> td;
+      onnxruntime::Barrier b(num_concurrent - 1);
+
+      // Each concurrent test runs with its own set of counters
+      for (int c = 0; c < num_concurrent; c++) {
+        td.push_back(CreateTestData(num_tasks));
+      }
+
+      // Run all but one of the loops via the thread pool, and the final one directly
+      for (int c = 0; c < num_concurrent - 1; c++) {
+        tp->Schedule([&, c]() {
+          tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) {
+            IncrementElement(*td[c], i);
+          });
+          b.Notify();
+        });
+      }
+
+      tp->SimpleParallelFor(num_tasks, [&](std::ptrdiff_t i) {
+        IncrementElement(*td[num_concurrent - 1], i);
+      });
+
+      // Validate all outputs
+      b.Wait();
+      for (int c = 0; c < num_concurrent; c++) {
+        ValidateTestData(*td[c]);
+      }
+      td.clear();
+    });
+  }
+}
+
 }  // namespace

 namespace onnxruntime {
@@ -102,6 +139,69 @@ TEST(ThreadPoolTest, TestBatchParallelFor_2_Thread_81_Task_20_Batch) {
   TestBatchParallelFor("TestBatchParallelFor_2_Thread_81_Task_20_Batch", 2, 81, 20);
 }

+TEST(ThreadPoolTest, TestMultipleParallelFor_1Thread_1Conc_0Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_1Thread_1Conc_0Tasks", 1, 1, 0);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_1Thread_1Conc_1Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_1Thread_1Conc_1Tasks", 1, 1, 1);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_1Thread_1Conc_8Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_1Thread_1Conc_8Tasks", 1, 1, 8);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_1Thread_1Conc_1MTasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_1Thread_1Conc_1MTasks", 1, 1, 1000000);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_1Thread_4Conc_0Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_1Thread_4Conc_0Tasks", 1, 4, 0);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_1Thread_4Conc_1Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_1Thread_4Conc_1Tasks", 1, 4, 1);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_1Thread_4Conc_8Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_1Thread_4Conc_8Tasks", 1, 4, 8);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_1Thread_4Conc_1MTasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_1Thread_4Conc_1MTasks", 1, 4, 1000000);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_4Thread_1Conc_0Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_4Thread_1Conc_0Tasks", 4, 1, 0);
+}
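TestMultipleParallelFor synchronizes the num_concurrent - 1 scheduled loops against the loop run in-line by constructing a Barrier with count num_concurrent - 1 and calling Notify() once per scheduled loop. A stripped-down sketch of the same pattern follows; SimpleBarrier is a stand-in built on standard primitives, not onnxruntime::Barrier.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Minimal stand-in for the barrier used above: Wait() blocks until Notify()
// has been called `count` times.
class SimpleBarrier {
 public:
  explicit SimpleBarrier(unsigned count) : remaining_(count) {}
  void Notify() {
    std::lock_guard<std::mutex> l(mu_);
    if (--remaining_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [&] { return remaining_ == 0; });
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  unsigned remaining_;
};

int main() {
  const int num_concurrent = 4;
  SimpleBarrier b(num_concurrent - 1);
  std::vector<std::thread> workers;
  for (int c = 0; c < num_concurrent - 1; c++) {
    workers.emplace_back([&] {
      // ... one parallel loop's work would run here ...
      b.Notify();
    });
  }
  // The calling thread runs the final loop itself, then waits for the rest.
  b.Wait();
  for (auto& t : workers) t.join();
  std::cout << "all loops finished\n";
  return 0;
}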
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_4Thread_1Conc_1Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_4Thread_1Conc_1Tasks", 4, 1, 1);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_4Thread_1Conc_8Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_4Thread_1Conc_8Tasks", 4, 1, 8);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_4Thread_1Conc_1MTasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_4Thread_1Conc_1MTasks", 4, 1, 1000000);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_4Thread_4Conc_0Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_4Thread_4Conc_0Tasks", 4, 4, 0);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_4Thread_4Conc_1Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_4Thread_4Conc_1Tasks", 4, 4, 1);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_4Thread_4Conc_8Tasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_4Thread_4Conc_8Tasks", 4, 4, 8);
+}
+
+TEST(ThreadPoolTest, TestMultipleParallelFor_4Thread_4Conc_1MTasks) {
+  TestMultipleParallelFor("TestMultipleParallelFor_4Thread_4Conc_1MTasks", 4, 4, 1000000);
+}
+
 #ifdef _WIN32
 TEST(ThreadPoolTest, TestStackSize) {
   ThreadOptions to;