Skip to content

Commit

Permalink
Fix dmlc#505: data race in threaded CSV parser
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Mar 7, 2019
1 parent 7e2a1ec commit 1983061
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 32 deletions.
34 changes: 19 additions & 15 deletions include/dmlc/threadediter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <functional>
#include <mutex>
#include <queue>
#include <atomic>
#include <thread>
#include "./data.h"
#include "./logging.h"
Expand Down Expand Up @@ -222,7 +223,7 @@ class ThreadedIter : public DataIter<DType> {
/*! \brief thread that runs the producer */
std::thread *producer_thread_;
/*! \brief whether produce ends */
bool produce_end_;
std::atomic<bool> produce_end_;
/*! \brief maximum queue size */
size_t max_capacity_;
/*! \brief internal mutex */
Expand Down Expand Up @@ -301,7 +302,7 @@ inline void ThreadedIter<DType>::Init(std::function<bool(DType **)> next,
std::function<void()> beforefirst) {
producer_sig_ = kProduce;
producer_sig_processed_ = false;
produce_end_ = false;
produce_end_.store(false, std::memory_order_release);
ClearException();
// procedure running in prodcuer
// run producer thread
Expand All @@ -315,8 +316,9 @@ inline void ThreadedIter<DType>::Init(std::function<bool(DType **)> next,
++this->nwait_producer_;
producer_cond_.wait(lock, [this]() {
if (producer_sig_ == kProduce) {
bool ret = !produce_end_ && (queue_.size() < max_capacity_ ||
free_cells_.size() != 0);
bool ret = !produce_end_.load(std::memory_order_acquire)
&& (queue_.size() < max_capacity_ ||
free_cells_.size() != 0);
return ret;
} else {
return true;
Expand All @@ -337,7 +339,7 @@ inline void ThreadedIter<DType>::Init(std::function<bool(DType **)> next,
queue_.pop();
}
// reset the state
produce_end_ = false;
produce_end_.store(false, std::memory_order_release);
producer_sig_processed_ = true;
producer_sig_ = kProduce;
// notify consumer that all the process as been done.
Expand All @@ -348,19 +350,19 @@ inline void ThreadedIter<DType>::Init(std::function<bool(DType **)> next,
// destroy the thread
DCHECK(producer_sig_ == kDestroy);
producer_sig_processed_ = true;
produce_end_ = true;
produce_end_.store(true, std::memory_order_release);
consumer_cond_.notify_all();
return;
}
} // end of lock scope
// now without lock
produce_end_ = !next(&cell);
DCHECK(cell != NULL || produce_end_);
produce_end_.store(!next(&cell), std::memory_order_release);
DCHECK(cell != NULL || produce_end_.load(std::memory_order_acquire));
bool notify;
{
// lockscope
std::lock_guard<std::mutex> lock(mutex_);
if (!produce_end_) {
if (!produce_end_.load(std::memory_order_acquire)) {
queue_.push(cell);
} else {
if (cell != NULL)
Expand Down Expand Up @@ -388,12 +390,12 @@ inline void ThreadedIter<DType>::Init(std::function<bool(DType **)> next,
free_cells_.push(queue_.front());
queue_.pop();
}
produce_end_ = true;
produce_end_.store(true, std::memory_order_release);
producer_sig_processed_ = true;
lock.unlock();
consumer_cond_.notify_all();
} else if (producer_sig_ == kProduce) {
produce_end_ = true;
produce_end_.store(true, std::memory_order_release);
next_notify = nwait_consumer_ != 0;
lock.unlock();
if (next_notify)
Expand All @@ -417,20 +419,22 @@ inline bool ThreadedIter<DType>::Next(DType **out_dptr) {
<< "Make sure you call BeforeFirst not inconcurrent with Next!";
++nwait_consumer_;
consumer_cond_.wait(lock,
[this]() { return queue_.size() != 0 || produce_end_; });
[this]() { return queue_.size() != 0
|| produce_end_.load(std::memory_order_acquire); });
--nwait_consumer_;
if (queue_.size() != 0) {
*out_dptr = queue_.front();
queue_.pop();
bool notify = nwait_producer_ != 0 && !produce_end_;
bool notify = nwait_producer_ != 0
&& !produce_end_.load(std::memory_order_acquire);
lock.unlock();
if (notify)
producer_cond_.notify_one();

ThrowExceptionIfSet();
return true;
} else {
CHECK(produce_end_);
CHECK(produce_end_.load(std::memory_order_acquire));
lock.unlock();

ThrowExceptionIfSet();
Expand All @@ -446,7 +450,7 @@ inline void ThreadedIter<DType>::Recycle(DType **inout_dptr) {
std::lock_guard<std::mutex> lock(mutex_);
free_cells_.push(*inout_dptr);
*inout_dptr = NULL;
notify = nwait_producer_ != 0 && !produce_end_;
notify = nwait_producer_ != 0 && !produce_end_.load(std::memory_order_acquire);
}
if (notify)
producer_cond_.notify_one();
Expand Down
38 changes: 21 additions & 17 deletions src/data/text_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,27 @@ inline bool TextParserBase<IndexType, DType>::FillData(
bytes_read_ += chunk.size;
CHECK_NE(chunk.size, 0U);
const char *head = reinterpret_cast<char *>(chunk.dptr);
#pragma omp parallel num_threads(nthread)
{
omp_exc_.Run([&] {
// threadid
int tid = omp_get_thread_num();
size_t nstep = (chunk.size + nthread - 1) / nthread;
size_t sbegin = std::min(tid * nstep, chunk.size);
size_t send = std::min((tid + 1) * nstep, chunk.size);
const char *pbegin = BackFindEndLine(head + sbegin, head);
const char *pend;
if (tid + 1 == nthread) {
pend = head + send;
} else {
pend = BackFindEndLine(head + send, head);
}
ParseBlock(pbegin, pend, &(*data)[tid]);
});

std::vector<std::thread> threads;
for (int tid = 0; tid < nthread; ++tid) {
threads.push_back(std::thread([&chunk, head, data, nthread, tid, this] {
this->omp_exc_.Run([&] {
size_t nstep = (chunk.size + nthread - 1) / nthread;
size_t sbegin = std::min(tid * nstep, chunk.size);
size_t send = std::min((tid + 1) * nstep, chunk.size);
const char *pbegin = BackFindEndLine(head + sbegin, head);
const char *pend;
if (tid + 1 == nthread) {
pend = head + send;
} else {
pend = BackFindEndLine(head + send, head);
}
ParseBlock(pbegin, pend, &(*data)[tid]);
});
}));
}
for (int i = 0; i < nthread; ++i) {
threads[i].join();
}
omp_exc_.Rethrow();

Expand Down

0 comments on commit 1983061

Please sign in to comment.