diff --git a/.gitignore b/.gitignore
index a63de96ac6d6..3292dfc4e309 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,8 @@
 dmlc-core
 mshadow
 config.mk
+
+# vim
+*.swp
+*.swo
+*.swn
diff --git a/Makefile b/Makefile
index b159e0bc9429..49dde6b7b687 100644
--- a/Makefile
+++ b/Makefile
@@ -48,7 +48,7 @@ ifneq ($(ADD_LDFLAGS), NONE)
 endif
 
 OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o
-OBJCXX11 = engine.o narray.o
+OBJCXX11 = engine.o narray.o threaded_engine.o
 CUOBJ = narray_op_gpu.o operator_gpu.o
 
 LIB_DEP = $(DMLC_CORE)/libdmlc.a
@@ -62,6 +62,7 @@ $(DMLC_CORE)/libdmlc.a:
 
 storage.o: src/storage/storage.cc
 engine.o: src/dag_engine/simple_engine.cc
+threaded_engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h
 narray.o: src/narray/narray.cc
 narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h
 narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h
diff --git a/src/common/concurrent_blocking_queue.h b/src/common/concurrent_blocking_queue.h
new file mode 100644
index 000000000000..aab39895b119
--- /dev/null
+++ b/src/common/concurrent_blocking_queue.h
@@ -0,0 +1,79 @@
+#pragma once
+#include <atomic>
+#include <cassert>
+#include <condition_variable>
+#include <list>
+#include <mutex>
+#include <utility>
+
+// Multi-producer/multi-consumer FIFO. Pop() spins briefly before sleeping.
+template<typename T> class ConcurrentBlockingQueue {
+  const static int BUSY_LOOP = 1000;  // spin iterations before falling back to cv wait
+ public:
+  ConcurrentBlockingQueue() : has_elmt_(false), exit_now_(false) {
+  }
+  // Append `e`; wakes sleeping consumers when the queue transitions to non-empty.
+  void Push(const T& e) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    has_elmt_ = true;
+    queue_.push_back(e);
+    if (queue_.size() == 1) {
+      cv_.notify_all();
+    }
+  }
+  // Blocking pop: returns false with `rv` filled, or true if killed.
+  bool Pop(T& rv) {
+    for (int i = 0; i < BUSY_LOOP; i++) {
+      if (has_elmt_) {
+        std::lock_guard<std::mutex> lock(mutex_);
+        if (!has_elmt_) {  // raced with another consumer
+          assert(queue_.empty());
+          continue;
+        }
+        rv = queue_.front();
+        queue_.pop_front();
+        if (queue_.empty())
+          has_elmt_ = false;
+        return false;
+      }
+    }
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      while (queue_.empty() && !exit_now_) {
+        cv_.wait(lock);
+      }
+      if (!exit_now_) {
+        rv = queue_.front();
+        queue_.pop_front();
+        if (queue_.empty())
+          has_elmt_ = false;
+        return false;
+      } else {
+        return true;
+      }
+    }
+  }
+  std::list<T> PopAll() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    std::list<T> rv;
+    rv.swap(queue_);
+    return rv;
+  }
+  // Call `SignalForKill` before destruction
+  void SignalForKill() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    exit_now_ = true;
+    cv_.notify_all();
+  }
+  size_t QueueSize() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    return queue_.size();
+  }
+
+ private:
+  std::atomic<bool> has_elmt_;
+  std::list<T> queue_;
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  std::atomic<bool> exit_now_;
+
+  ConcurrentBlockingQueue(const ConcurrentBlockingQueue&) = delete;
+  ConcurrentBlockingQueue& operator=(const ConcurrentBlockingQueue&) = delete;
+};
diff --git a/src/common/spin_lock.h b/src/common/spin_lock.h
new file mode 100644
index 000000000000..5a0cc3f786e6
--- /dev/null
+++ b/src/common/spin_lock.h
@@ -0,0 +1,45 @@
+#ifndef _SPINLOCK_XCHG_H
+#define _SPINLOCK_XCHG_H
+
+/* Spin lock using xchg.
+ * Copied from http://locklessinc.com/articles/locks/
+ */
+
+/* Compile read-write barrier */
+#define barrier() asm volatile("": : :"memory")
+
+/* Pause instruction to prevent excess processor bus usage */
+#define cpu_relax() asm volatile("pause\n": : :"memory")
+
+static inline unsigned short xchg_8(void *ptr, unsigned char x) {
+  __asm__ __volatile__("xchgb %0,%1"
+                       :"=r" (x)
+                       :"m" (*(volatile unsigned char *)ptr), "0" (x)
+                       :"memory");
+
+  return x;
+}
+
+#define BUSY 1
+typedef unsigned char spinlock;
+
+#define SPINLOCK_INITIALIZER 0
+
+static inline void spin_lock(spinlock *lock) {
+  while (1) {
+    if (!xchg_8(lock, BUSY)) return;
+
+    while (*lock) cpu_relax();
+  }
+}
+
+static inline void spin_unlock(spinlock *lock) {
+  barrier();
+  *lock = 0;
+}
+
+static inline int spin_trylock(spinlock *lock) {
+  return xchg_8(lock, BUSY);
+}
+
+#endif /* _SPINLOCK_XCHG_H */
diff --git a/src/dag_engine/threaded_engine.cc b/src/dag_engine/threaded_engine.cc
new file mode 100644
index 000000000000..143b5e72f413
--- /dev/null
+++ b/src/dag_engine/threaded_engine.cc
@@ -0,0 +1,179 @@
+#include <atomic>
+#include <memory>
+#include <queue>
+#include <random>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include <dmlc/logging.h>
+#include <mxnet/dag_engine.h>
+#include "../common/spin_lock.h"
+#include "../common/concurrent_blocking_queue.h"
+
+using namespace std;
+
+namespace mxnet {
+
+#define DEFAULT_NUM_WORKER_THREADS 4
+
+class ThreadedEngine : public DAGEngine {
+ public:
+  explicit ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) {
+    for(int i = 0; i < numthreads; ++i) {
+      worker_queues_.push_back(new ConcurrentBlockingQueue<OpDescr*>());
+      workers_.emplace_back(&ThreadedEngine::WorkerRoutine, this, i);
+    }
+  }
+  ~ThreadedEngine() {
+    for(int i = 0; i < numthreads_; ++i) {
+      worker_queues_[i]->SignalForKill();
+      workers_[i].join();  // FIX: join before freeing the queue the worker is blocked on
+      delete worker_queues_[i];
+    }
+  }
+  void Push(AsyncOp exec_fun,
+            Context exec_ctx,
+            const vector<Variable> &use_vars,
+            const vector<Variable> &mutate_vars) override {
+    shared_ptr<OpDescr> opd( new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars},
+        [this] (OpDescr* o) { this->OnDepsResolved(o); } );
+    for( Variable v : use_vars ) { // read
+      VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
+      spin_lock(&vard->lock);
+      if (vard->rw < 0) {
+        vard->waitings.push(make_pair(opd, DepType::kRead));
+      } else {
+        ++vard->rw;
+      }
+      spin_unlock(&vard->lock);
+    }
+    for( Variable v : mutate_vars ) { // write
+      VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
+      spin_lock(&vard->lock);
+      if (vard->rw != 0) {
+        vard->waitings.push(make_pair(opd, DepType::kWrite));
+      } else {
+        vard->rw = -1;
+      }
+      spin_unlock(&vard->lock);
+    }
+  }
+  void Push(Op exec_fun,
+            Context exec_ctx,
+            const vector<Variable> &use_vars,
+            const vector<Variable> &mutate_vars) override {
+    this->Push([exec_fun](RunContext ctx, Callback on_complete) {
+        exec_fun(ctx); on_complete();
+      }, exec_ctx, use_vars, mutate_vars);
+  }
+  void PushDelete(Op delete_fun, Variable var) override {
+    // TODO
+    this->Push([delete_fun, var] (RunContext ctx) {
+        delete_fun(ctx);
+        delete static_cast<VarDescr*>(var);
+      }, Context()/* TODO exec_ctx is missing?*/, {}, {var});
+  }
+  Variable NewVar() override {
+    // in practice return a ptr to a cell
+    // that have the info about the variable
+    // use ptr directly instead of ID because this avoids an indirect mapping
+    VarDescr* vd = new VarDescr;
+    vd->lock = SPINLOCK_INITIALIZER;
+    vd->rw = 0;
+    return vd;
+  }
+  void WaitForVar(Variable var) override {
+    // TODO
+  }
+  void WaitForAll() override {
+    // TODO
+  }
+ private:
+  enum class DepType {
+    kRead = 0,
+    kWrite,
+    kDelete,
+  };
+  struct OpDescr {
+    AsyncOp op;
+    Context exec_ctx;
+    vector<Variable> read_vars;
+    vector<Variable> write_vars;
+  };
+  struct VarDescr {
+    spinlock lock;
+    int rw; // a semaphore-like count
+            // if rw > 0, the variable has several readers and the number
+            // means how many operators are currently reading it;
+            // if rw < 0, the varaible has one writer (should be -1)
+    queue<pair<shared_ptr<OpDescr>, DepType>> waitings;
+  };
+  void TriggerWaiting(VarDescr* vard) {
+    // ATTENTION: this function should be called with vard->lock held.
+    CHECK(vard->rw == 0) << "the variable should be free during triggering";
+    if(!vard->waitings.empty()) {
+      // pop all reads first
+      while(!vard->waitings.empty() && vard->waitings.front().second == DepType::kRead) {  // FIX: re-check empty
+        vard->waitings.pop();
+        ++vard->rw;
+      }
+      if (vard->rw == 0 && !vard->waitings.empty()) {  // FIX: guard empty before popping the write
+        // if the next one is a delete
+        // pop the next write
+        vard->waitings.pop();
+        vard->rw = -1;
+      }
+    }
+  }
+  void OnOpFinished(OpDescr* opd) {
+    CHECK(opd) << "completing a nullptr op!";
+    for(Variable v : opd->read_vars) {
+      VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
+      spin_lock(&vard->lock);
+      CHECK(vard->rw > 0) << "incorrect rw count (reader):" << vard->rw;
+      if(--vard->rw == 0) {
+        TriggerWaiting(vard);
+      }
+      spin_unlock(&vard->lock);
+    }
+    for(Variable v : opd->write_vars) {
+      VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
+      spin_lock(&vard->lock);
+      CHECK(vard->rw == -1) << "incorrect rw count (writer):" << vard->rw;
+      vard->rw = 0;
+      TriggerWaiting(vard);
+      spin_unlock(&vard->lock);
+    }
+    delete opd; // delete the operator
+  }
+  RunContext GetRunContext(const Context& ctx) {
+    // TODO
+    return RunContext();
+  }
+  void OnDepsResolved(OpDescr* opd) {
+    static default_random_engine generator;
+    static uniform_int_distribution<int> distribution(0, numthreads_ - 1);  // FIX: bounds are inclusive
+    int thrid = distribution(generator);
+    worker_queues_[thrid]->Push(opd);
+  }
+  void WorkerRoutine(int thrid) {
+    OpDescr* opd = nullptr;
+    while(!worker_queues_[thrid]->Pop(opd)) {
+      LOG(INFO) << "worker thread #" << thrid << " got operator " << opd;
+      opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); });
+      opd = nullptr;
+    }
+  }
+ private:
+  const int numthreads_;
+  vector<ConcurrentBlockingQueue<OpDescr*>*> worker_queues_;
+  vector<thread> workers_;
+};
+
+// implements the singleton factory
+DAGEngine* DAGEngine::Get() {
+  static ThreadedEngine engine;
+  return &engine;
+}
+} // namespace mxnet
diff --git a/test/test_threaded_engine.cc b/test/test_threaded_engine.cc
new file mode 100644
index 000000000000..40dea029cf6e
--- /dev/null
+++ b/test/test_threaded_engine.cc
@@ -0,0 +1,9 @@
+#include <mxnet/dag_engine.h>
+
+using namespace std;
+using namespace mxnet;
+
+int main() {
+  DAGEngine* engine = DAGEngine::Get();
+  return 0;
+}