This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

threaded engine draft
jermainewang committed Jun 29, 2015
1 parent df4227e commit 438ab2b
Showing 6 changed files with 319 additions and 1 deletion.
5 changes: 5 additions & 0 deletions .gitignore
@@ -30,3 +30,8 @@
dmlc-core
mshadow
config.mk

# vim
*.swp
*.swo
*.swn
3 changes: 2 additions & 1 deletion Makefile
@@ -48,7 +48,7 @@ ifneq ($(ADD_LDFLAGS), NONE)
endif

OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o
-OBJCXX11 = engine.o narray.o
+OBJCXX11 = engine.o narray.o threaded_engine.o
CUOBJ = narray_op_gpu.o operator_gpu.o

LIB_DEP = $(DMLC_CORE)/libdmlc.a
@@ -62,6 +62,7 @@ $(DMLC_CORE)/libdmlc.a:

storage.o: src/storage/storage.cc
engine.o: src/dag_engine/simple_engine.cc
threaded_engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h
narray.o: src/narray/narray.cc
narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h
narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h
79 changes: 79 additions & 0 deletions src/common/concurrent_blocking_queue.h
@@ -0,0 +1,79 @@
#pragma once
#include <list>
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <thread>
#include <cassert>
#include <cstdio>

template<typename T> class ConcurrentBlockingQueue {
const static int BUSY_LOOP = 1000;
public:
ConcurrentBlockingQueue() : has_elmt_(false), exit_now_(false) {
}
void Push(const T& e) {
std::lock_guard<std::mutex> lock(mutex_);
has_elmt_ = true;
queue_.push_back(e);
if (queue_.size() == 1) {
cv_.notify_all();
}
}
// Returns false when an element has been popped into rv;
// returns true only after SignalForKill() has been called.
bool Pop(T& rv) {
for (int i = 0; i < BUSY_LOOP; i++) {
if (has_elmt_) {
std::lock_guard<std::mutex> lock(mutex_);
if (!has_elmt_) {
assert(queue_.empty());
continue;
}
rv = queue_.front();

Review comment from mli (Member), Jul 1, 2015:

    rv = std::move(queue_.front()) might be more efficient

(A sketch of that variant follows after this file's diff.)

queue_.pop_front();
if (queue_.empty())
has_elmt_ = false;
return false;
}
}
{
std::unique_lock<std::mutex> lock(mutex_);
while (queue_.empty() && !exit_now_) {
cv_.wait(lock);
}
if (!exit_now_) {
rv = queue_.front();
queue_.pop_front();
if (queue_.empty())
has_elmt_ = false;
return false;
} else {
return true;
}
}
}
std::list<T> PopAll() {
std::lock_guard<std::mutex> lock(mutex_);
std::list<T> rv;
rv.swap(queue_);
return rv;
}
// Call `SignalForKill` before destruction
void SignalForKill() {
std::unique_lock<std::mutex> lock(mutex_);
exit_now_ = true;
cv_.notify_all();
}
size_t QueueSize() {
std::unique_lock<std::mutex> lock(mutex_);
return queue_.size();
}

private:
std::atomic<bool> has_elmt_;
std::list<T> queue_;
std::mutex mutex_;
std::condition_variable cv_;
std::atomic<bool> exit_now_;

ConcurrentBlockingQueue(const ConcurrentBlockingQueue&) = delete;
ConcurrentBlockingQueue& operator=(const ConcurrentBlockingQueue&) = delete;
};
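A minimal, self-contained sketch of the move-based hand-off mli suggests (illustration only, not part of the commit; take_front is a hypothetical stand-in for the element-removal step inside Pop, and T is assumed to be move-assignable):

// Illustration of mli's suggestion, not part of the commit: hand the front
// element out with std::move instead of a copy.
#include <cassert>
#include <list>
#include <string>
#include <utility>

template <typename T>
bool take_front(std::list<T>& q, T& rv) {
  if (q.empty()) return false;
  rv = std::move(q.front());  // moves instead of copying the element
  q.pop_front();
  return true;
}

int main() {
  std::list<std::string> q{"hello", "world"};
  std::string s;
  bool ok = take_front(q, s);
  assert(ok && s == "hello");  // "hello" was moved out, not deep-copied
  return 0;
}

For a T that owns resources (strings, shared_ptr, large buffers) the move saves a deep copy or an atomic reference-count bump on every Pop.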
45 changes: 45 additions & 0 deletions src/common/spin_lock.h
@@ -0,0 +1,45 @@
#ifndef _SPINLOCK_XCHG_H
#define _SPINLOCK_XCHG_H

/* Spin lock using xchg.
* Copied from http://locklessinc.com/articles/locks/
*/

/* Compile read-write barrier */
#define barrier() asm volatile("": : :"memory")

/* Pause instruction to prevent excess processor bus usage */
#define cpu_relax() asm volatile("pause\n": : :"memory")

static inline unsigned short xchg_8(void *ptr, unsigned char x) {
__asm__ __volatile__("xchgb %0,%1"
:"=r" (x)
:"m" (*(volatile unsigned char *)ptr), "0" (x)
:"memory");

return x;
}

#define BUSY 1
typedef unsigned char spinlock;

#define SPINLOCK_INITIALIZER 0

static inline void spin_lock(spinlock *lock) {
while (1) {
if (!xchg_8(lock, BUSY)) return;

while (*lock) cpu_relax();
}
}

static inline void spin_unlock(spinlock *lock) {
barrier();
*lock = 0;
}

/* Returns 0 if the lock was acquired, non-zero if it was already held. */
static inline int spin_trylock(spinlock *lock) {
  return xchg_8(lock, BUSY);
}

#endif /* _SPINLOCK_XCHG_H */
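For context, a usage sketch of the spinlock API above (illustrative only; it assumes an x86/x86-64 build, since xchgb and pause are x86 instructions, and that the header is reachable as "spin_lock.h" on the include path):

// Illustrative usage of spin_lock.h: four threads increment a shared counter.
#include <cstdio>
#include <thread>
#include <vector>
#include "spin_lock.h"

static spinlock counter_lock = SPINLOCK_INITIALIZER;
static long counter = 0;

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) {
    threads.emplace_back([] {
      for (int j = 0; j < 100000; ++j) {
        spin_lock(&counter_lock);   // busy-waits with cpu_relax() until acquired
        ++counter;                  // critical section
        spin_unlock(&counter_lock);
      }
    });
  }
  for (auto& t : threads) t.join();
  std::printf("counter = %ld\n", counter);  // expect 400000
  return 0;
}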
179 changes: 179 additions & 0 deletions src/dag_engine/threaded_engine.cc
@@ -0,0 +1,179 @@
#include <queue>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
#include <atomic>
#include <thread>
#include <random>

#include <dmlc/logging.h>
#include <mxnet/dag_engine.h>
#include "../common/spin_lock.h"
#include "../common/concurrent_blocking_queue.h"

using namespace std;

namespace mxnet {

#define DEFAULT_NUM_WORKER_THREADS 4

class ThreadedEngine : public DAGEngine {
public:
ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) {
for(int i = 0; i < numthreads; ++i) {
worker_queues_.push_back(new ConcurrentBlockingQueue<OpDescr*>());
workers_.emplace_back(&ThreadedEngine::WorkerRoutine, this, i);
}
}
~ThreadedEngine() {
  // signal all workers first, then join each one before freeing its queue
  for(int i = 0; i < numthreads_; ++i) {
    worker_queues_[i]->SignalForKill();
  }
  for(int i = 0; i < numthreads_; ++i) {
    workers_[i].join();
    delete worker_queues_[i];
  }
}
void Push(AsyncOp exec_fun,
Context exec_ctx,
const vector<Variable> &use_vars,
const vector<Variable> &mutate_vars) override {
shared_ptr<OpDescr> opd( new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars},
[this] (OpDescr* o) { this->OnDepsResolved(o); } );
for( Variable v : use_vars ) { // read
VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
spin_lock(&vard->lock);
if (vard->rw < 0) {
vard->waitings.push(make_pair(opd, DepType::kRead));
} else {
++vard->rw;
}
spin_unlock(&vard->lock);
}
for( Variable v : mutate_vars ) { // write
VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
spin_lock(&vard->lock);
if (vard->rw != 0) {
vard->waitings.push(make_pair(opd, DepType::kWrite));
} else {
vard->rw = -1;
}
spin_unlock(&vard->lock);
}
}
void Push(Op exec_fun,
Context exec_ctx,
const vector<Variable> &use_vars,
const vector<Variable> &mutate_vars) override {
this->Push([exec_fun](RunContext ctx, Callback on_complete) {
exec_fun(ctx); on_complete();
}, exec_ctx, use_vars, mutate_vars);
}
void PushDelete(Op delete_fun, Variable var) override {
// TODO
this->Push([delete_fun, var] (RunContext ctx) {
delete_fun(ctx);
delete static_cast<VarDescr*>(var);
}, Context()/* TODO exec_ctx is missing?*/, {}, {var});
}
Variable NewVar() override {
// in practice, return a pointer to a cell
// that has the info about the variable;
// use the pointer directly instead of an ID to avoid an indirect mapping
VarDescr* vd = new VarDescr;
vd->lock = SPINLOCK_INITIALIZER;
vd->rw = 0;
return vd;
}
void WaitForVar(Variable var) override {
// TODO
}
void WaitForAll() override {
// TODO
}
private:
enum class DepType {
kRead = 0,
kWrite,
kDelete,
};
struct OpDescr {
AsyncOp op;
Context exec_ctx;
vector<Variable> read_vars;
vector<Variable> write_vars;
};
struct VarDescr {
spinlock lock;
int rw; // a semaphore-like count
// if rw > 0, the variable has several readers and the number
// means how many operators are currently reading it;
// if rw < 0, the variable has one writer (should be -1)
queue<pair<shared_ptr<OpDescr>, DepType>> waitings;
};
void TriggerWaiting(VarDescr* vard) {
  // ATTENTION: this function should be called with vard->lock held.
  CHECK(vard->rw == 0) << "the variable should be free during triggering";
  // pop all pending reads at the head of the queue first
  while (!vard->waitings.empty()
         && vard->waitings.front().second == DepType::kRead) {
    vard->waitings.pop();
    ++vard->rw;
  }
  if (vard->rw == 0 && !vard->waitings.empty()) {
    // no reads were at the head, so the next waiting op is a write
    // (or delete): pop it and mark the variable as being written
    vard->waitings.pop();
    vard->rw = -1;
  }
}
void OnOpFinished(OpDescr* opd) {
CHECK(opd) << "completing a nullptr op!";
for(Variable v : opd->read_vars) {
VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
spin_lock(&vard->lock);
CHECK(vard->rw > 0) << "incorrect rw count (reader):" << vard->rw;
if(--vard->rw == 0) {
TriggerWaiting(vard);
}
spin_unlock(&vard->lock);
}
for(Variable v : opd->write_vars) {
VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here
spin_lock(&vard->lock);
CHECK(vard->rw == -1) << "incorrect rw count (writer):" << vard->rw;
vard->rw = 0;
TriggerWaiting(vard);
spin_unlock(&vard->lock);
}
delete opd; // delete the operator
}
RunContext GetRunContext(const Context& ctx) {
// TODO
return RunContext();
}
void OnDepsResolved(OpDescr* opd) {
  static default_random_engine generator;
  // uniform_int_distribution is inclusive on both ends
  static uniform_int_distribution<int> distribution(0, numthreads_ - 1);
  int thrid = distribution(generator);
  worker_queues_[thrid]->Push(opd);
}
void WorkerRoutine(int thrid) {
OpDescr* opd = nullptr;
while(! worker_queues_[thrid]->Pop(opd)) {
LOG(INFO) << "worker thread #" << thrid << " got operator " << opd;
opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); });
opd = nullptr;
}
}
private:
const int numthreads_;
vector<ConcurrentBlockingQueue<OpDescr*>*> worker_queues_;
vector<thread> workers_;
};

// implements the singleton factory
DAGEngine* DAGEngine::Get() {
static ThreadedEngine engine;
return &engine;
}
} // namespace mxnet
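One detail of the draft worth spelling out: in Push, every variable that blocks an op holds a copy of a single shared_ptr<OpDescr> whose custom deleter is OnDepsResolved, so the op is handed to a worker queue exactly once, as soon as the last blocking variable releases its reference (or immediately, at the end of Push, if nothing blocked it). A standalone sketch of that idiom, with illustrative names that are not part of the engine:

// Standalone illustration of the shared_ptr-with-custom-deleter dispatch idiom.
// Task, dispatch, and pending are illustrative names, not part of the engine.
#include <cstdio>
#include <memory>
#include <vector>

struct Task { const char* name; };

void dispatch(Task* t) {   // plays the role of OnDepsResolved
  std::printf("dispatching %s\n", t->name);
  delete t;                // the engine instead pushes t onto a worker queue
}

int main() {
  std::vector<std::shared_ptr<Task>> pending;  // stands in for per-variable wait queues
  {
    std::shared_ptr<Task> op(new Task{"op1"}, dispatch);
    pending.push_back(op);  // dependency #1 holds a reference
    pending.push_back(op);  // dependency #2 holds a reference
  }                         // Push returns; nothing is dispatched yet
  pending.pop_back();       // dependency #2 resolved
  pending.pop_back();       // last reference released -> dispatch runs exactly once
  return 0;
}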
9 changes: 9 additions & 0 deletions test/test_threaded_engine.cc
@@ -0,0 +1,9 @@
#include <mxnet/dag_engine.h>

using namespace std;
using namespace mxnet;

int main() {
DAGEngine* engine = DAGEngine::Get();
return 0;
}
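The test above only constructs the singleton engine. A slightly fuller smoke test might look like the sketch below; it is illustrative only and leans on assumptions about types not shown in this commit: that a one-argument lambda taking RunContext converts unambiguously to Op, that Context() is default-constructible as in PushDelete, and that NewVar's return can be held with auto. Since WaitForAll() is still a TODO in the draft, it sleeps briefly instead of synchronizing properly.

// Illustrative-only extension of the smoke test; see assumptions above.
#include <chrono>
#include <cstdio>
#include <thread>
#include <mxnet/dag_engine.h>

using namespace mxnet;

int main() {
  DAGEngine* engine = DAGEngine::Get();
  auto var = engine->NewVar();
  // writer: mutates var
  engine->Push([](RunContext) { std::printf("write var\n"); },
               Context(), {}, {var});
  // reader: should run only after the write above has finished
  engine->Push([](RunContext) { std::printf("read var\n"); },
               Context(), {var}, {});
  // WaitForAll() is a TODO in this draft, so just give the workers a moment
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  return 0;
}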
