From 0f05c65492e38f89012e5595c6d60bb67e2d418f Mon Sep 17 00:00:00 2001 From: Eric Junyuan Xie Date: Fri, 8 Dec 2017 10:14:57 -0800 Subject: [PATCH] Fix race condition in engine start/stop (#8995) --- src/engine/threaded_engine_perdevice.cc | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 28bc92f7b207..e7e222f6cbe3 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -95,9 +95,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine { int cpu_priority_nthreads = dmlc::GetEnv("MXNET_CPU_PRIORITY_NTHREADS", 4); cpu_priority_worker_.reset(new ThreadWorkerBlock()); cpu_priority_worker_->pool.reset(new ThreadPool( - cpu_priority_nthreads, [this]() { - this->CPUWorker(Context(), cpu_priority_worker_.get()); - })); + cpu_priority_nthreads, + [this](std::shared_ptr ready_event) { + this->CPUWorker(Context(), cpu_priority_worker_.get(), ready_event); + }, true)); // GPU tasks will be created lazily } @@ -122,9 +123,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine { auto ptr = cpu_normal_workers_.Get(dev_id, [this, ctx, nthread]() { auto blk = new ThreadWorkerBlock(); - blk->pool.reset(new ThreadPool(nthread, [this, ctx, blk] () { - this->CPUWorker(ctx, blk); - })); + blk->pool.reset(new ThreadPool(nthread, + [this, ctx, blk](std::shared_ptr ready_event) { + this->CPUWorker(ctx, blk, ready_event); + }, true)); return blk; }); if (ptr) { @@ -259,12 +261,14 @@ class ThreadedEnginePerDevice : public ThreadedEngine { */ template inline void CPUWorker(Context ctx, - ThreadWorkerBlock *block) { + ThreadWorkerBlock *block, + std::shared_ptr ready_event) { this->is_worker_ = true; auto* task_queue = &(block->task_queue); RunContext run_ctx{ctx, nullptr}; // execute task OprBlock* opr_block; + ready_event->signal(); while (task_queue->Pop(&opr_block)) { this->ExecuteOprBlock(run_ctx, opr_block); }