From 1983061ab2ef71d3504611a93ff1b116a2c1cb3a Mon Sep 17 00:00:00 2001 From: Hyunsu Philip Cho Date: Thu, 7 Mar 2019 01:04:54 +0000 Subject: [PATCH] Fix #505: data race in threaded CSV parser --- include/dmlc/threadediter.h | 34 ++++++++++++++++++--------------- src/data/text_parser.h | 38 ++++++++++++++++++++----------------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/include/dmlc/threadediter.h b/include/dmlc/threadediter.h index c920156b23..8115b8a90e 100644 --- a/include/dmlc/threadediter.h +++ b/include/dmlc/threadediter.h @@ -16,6 +16,7 @@ #include #include #include +#include #include #include "./data.h" #include "./logging.h" @@ -222,7 +223,7 @@ class ThreadedIter : public DataIter { /*! \brief thread that runs the producer */ std::thread *producer_thread_; /*! \brief whether produce ends */ - bool produce_end_; + std::atomic produce_end_; /*! \brief maximum queue size */ size_t max_capacity_; /*! \brief internal mutex */ @@ -301,7 +302,7 @@ inline void ThreadedIter::Init(std::function next, std::function beforefirst) { producer_sig_ = kProduce; producer_sig_processed_ = false; - produce_end_ = false; + produce_end_.store(false, std::memory_order_release); ClearException(); // procedure running in prodcuer // run producer thread @@ -315,8 +316,9 @@ inline void ThreadedIter::Init(std::function next, ++this->nwait_producer_; producer_cond_.wait(lock, [this]() { if (producer_sig_ == kProduce) { - bool ret = !produce_end_ && (queue_.size() < max_capacity_ || - free_cells_.size() != 0); + bool ret = !produce_end_.load(std::memory_order_acquire) + && (queue_.size() < max_capacity_ || + free_cells_.size() != 0); return ret; } else { return true; @@ -337,7 +339,7 @@ inline void ThreadedIter::Init(std::function next, queue_.pop(); } // reset the state - produce_end_ = false; + produce_end_.store(false, std::memory_order_release); producer_sig_processed_ = true; producer_sig_ = kProduce; // notify consumer that all the process as been done. @@ -348,19 +350,19 @@ inline void ThreadedIter::Init(std::function next, // destroy the thread DCHECK(producer_sig_ == kDestroy); producer_sig_processed_ = true; - produce_end_ = true; + produce_end_.store(true, std::memory_order_release); consumer_cond_.notify_all(); return; } } // end of lock scope // now without lock - produce_end_ = !next(&cell); - DCHECK(cell != NULL || produce_end_); + produce_end_.store(!next(&cell), std::memory_order_release); + DCHECK(cell != NULL || produce_end_.load(std::memory_order_acquire)); bool notify; { // lockscope std::lock_guard lock(mutex_); - if (!produce_end_) { + if (!produce_end_.load(std::memory_order_acquire)) { queue_.push(cell); } else { if (cell != NULL) @@ -388,12 +390,12 @@ inline void ThreadedIter::Init(std::function next, free_cells_.push(queue_.front()); queue_.pop(); } - produce_end_ = true; + produce_end_.store(true, std::memory_order_release); producer_sig_processed_ = true; lock.unlock(); consumer_cond_.notify_all(); } else if (producer_sig_ == kProduce) { - produce_end_ = true; + produce_end_.store(true, std::memory_order_release); next_notify = nwait_consumer_ != 0; lock.unlock(); if (next_notify) @@ -417,12 +419,14 @@ inline bool ThreadedIter::Next(DType **out_dptr) { << "Make sure you call BeforeFirst not inconcurrent with Next!"; ++nwait_consumer_; consumer_cond_.wait(lock, - [this]() { return queue_.size() != 0 || produce_end_; }); + [this]() { return queue_.size() != 0 + || produce_end_.load(std::memory_order_acquire); }); --nwait_consumer_; if (queue_.size() != 0) { *out_dptr = queue_.front(); queue_.pop(); - bool notify = nwait_producer_ != 0 && !produce_end_; + bool notify = nwait_producer_ != 0 + && !produce_end_.load(std::memory_order_acquire); lock.unlock(); if (notify) producer_cond_.notify_one(); @@ -430,7 +434,7 @@ inline bool ThreadedIter::Next(DType **out_dptr) { ThrowExceptionIfSet(); return true; } else { - CHECK(produce_end_); + CHECK(produce_end_.load(std::memory_order_acquire)); lock.unlock(); ThrowExceptionIfSet(); @@ -446,7 +450,7 @@ inline void ThreadedIter::Recycle(DType **inout_dptr) { std::lock_guard lock(mutex_); free_cells_.push(*inout_dptr); *inout_dptr = NULL; - notify = nwait_producer_ != 0 && !produce_end_; + notify = nwait_producer_ != 0 && !produce_end_.load(std::memory_order_acquire); } if (notify) producer_cond_.notify_one(); diff --git a/src/data/text_parser.h b/src/data/text_parser.h index 68ee552fe8..97524b7f37 100644 --- a/src/data/text_parser.h +++ b/src/data/text_parser.h @@ -117,23 +117,27 @@ inline bool TextParserBase::FillData( bytes_read_ += chunk.size; CHECK_NE(chunk.size, 0U); const char *head = reinterpret_cast(chunk.dptr); -#pragma omp parallel num_threads(nthread) - { - omp_exc_.Run([&] { - // threadid - int tid = omp_get_thread_num(); - size_t nstep = (chunk.size + nthread - 1) / nthread; - size_t sbegin = std::min(tid * nstep, chunk.size); - size_t send = std::min((tid + 1) * nstep, chunk.size); - const char *pbegin = BackFindEndLine(head + sbegin, head); - const char *pend; - if (tid + 1 == nthread) { - pend = head + send; - } else { - pend = BackFindEndLine(head + send, head); - } - ParseBlock(pbegin, pend, &(*data)[tid]); - }); + + std::vector threads; + for (int tid = 0; tid < nthread; ++tid) { + threads.push_back(std::thread([&chunk, head, data, nthread, tid, this] { + this->omp_exc_.Run([&] { + size_t nstep = (chunk.size + nthread - 1) / nthread; + size_t sbegin = std::min(tid * nstep, chunk.size); + size_t send = std::min((tid + 1) * nstep, chunk.size); + const char *pbegin = BackFindEndLine(head + sbegin, head); + const char *pend; + if (tid + 1 == nthread) { + pend = head + send; + } else { + pend = BackFindEndLine(head + send, head); + } + ParseBlock(pbegin, pend, &(*data)[tid]); + }); + })); + } + for (int i = 0; i < nthread; ++i) { + threads[i].join(); } omp_exc_.Rethrow();