From bd36bbc1c2a49924385489bcdbe63c71dbf8c559 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 20 Sep 2015 14:37:45 -0700 Subject: [PATCH] [BUGFIX] Engine CompleteWriteDependency. to_delete_ stale value. --- src/engine/threaded_engine.cc | 17 ++++++++++------- src/io/iter_prefetcher.h | 22 +++++++++++----------- tests/python/train/test_mlp.py | 2 ++ 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 1b2ccdec796b..0889482535e9 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -83,12 +83,20 @@ void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { template bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { + // this is lock scope VersionedVarBlock *old_pending_write, *end_of_read_chain; - bool trigger_write = false; + bool trigger_write = false, to_delete = false; { - // this is lock scope std::lock_guard lock{m_}; assert(ready_to_read_ == false); + // really delete + if (to_delete_) { + VersionedVarBlock *head = pending_write_->next; + VersionedVarBlock::Delete(pending_write_); + assert(head->next == nullptr); + VersionedVarBlock::Delete(head); + return true; + } // detach pending write old_pending_write = pending_write_; // search for chains to trigger @@ -119,11 +127,6 @@ bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { // So it is safe to modify these VersionedVarBlock *cur_head = old_pending_write->next; VersionedVarBlock::Delete(old_pending_write); - if (to_delete_) { - assert(cur_head->next == nullptr); - VersionedVarBlock::Delete(cur_head); - return true; - } // dispatch all the events while (cur_head != end_of_read_chain) { if (--cur_head->trigger->wait == 0) { diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 01674f1f17e2..ddf0f053f34b 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -115,22 +115,22 @@ class PrefetcherIter : public IIterator { } virtual bool Next(void) { if (ready_batches_.size() == param_.prefetch_buffer) { - TBlobBatch* old_batch = ready_batches_.front(); - for (size_t i = 0; i < old_batch->data.size(); i++) { - NDArray old_ndarray = ready_ndarrays_.front(); - old_ndarray.WaitToWrite(); - ready_ndarrays_.pop(); - } - iter_.Recycle(&old_batch); - ready_batches_.pop(); + TBlobBatch* old_batch = ready_batches_.front(); + for (size_t i = 0; i < old_batch->data.size(); i++) { + NDArray old_ndarray = ready_ndarrays_.front(); + old_ndarray.WaitToWrite(); + ready_ndarrays_.pop(); + } + iter_.Recycle(&old_batch); + ready_batches_.pop(); } TBlobBatch* next_batch = NULL; if (!iter_.Next(&next_batch)) return false; out_.data.clear(); // copy the batch - for (size_t i = 0; i < next_batch->data.size(); i++) { - out_.data.push_back(NDArray(next_batch->data[i], mshadow::cpu::kDevMask)); - ready_ndarrays_.push(out_.data[i]); + for (size_t i = 0; i < next_batch->data.size(); ++i) { + out_.data.push_back(NDArray(next_batch->data[i], 0)); + ready_ndarrays_.push(out_.data[i]); } // push the narrays and batch into the queue ready_batches_.push(next_batch); diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index eb3502de958a..e863b1164258 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -46,7 +46,9 @@ def test_mlp(): model.fit(X=train_dataiter, eval_data=val_dataiter, iter_end_callback=mx.model.do_checkpoint(prefix)) + logging.info('Finish fit...') prob = model.predict(val_dataiter) + logging.info('Finish predict...') val_dataiter.reset() y = np.concatenate([label.asnumpy() for _, label in val_dataiter]).astype('int') py = np.argmax(prob, axis=1)