This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
df4227e
commit 438ab2b
Showing
6 changed files
with
319 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,3 +30,8 @@ | |
dmlc-core | ||
mshadow | ||
config.mk | ||
|
||
# vim | ||
*.swp | ||
*.swo | ||
*.swn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#pragma once | ||
#include <list> | ||
#include <mutex> | ||
#include <condition_variable> | ||
#include <atomic> | ||
#include <thread> | ||
#include <cstdio> | ||
|
||
template<typename T> class ConcurrentBlockingQueue { | ||
const static int BUSY_LOOP = 1000; | ||
public: | ||
ConcurrentBlockingQueue() : has_elmt_(false), exit_now_(false) { | ||
} | ||
void Push(const T& e) { | ||
std::lock_guard<std::mutex> lock(mutex_); | ||
has_elmt_ = true; | ||
queue_.push_back(e); | ||
if (queue_.size() == 1) { | ||
cv_.notify_all(); | ||
} | ||
} | ||
bool Pop(T& rv) { | ||
for (int i = 0; i < BUSY_LOOP; i++) { | ||
if (has_elmt_) { | ||
std::lock_guard<std::mutex> lock(mutex_); | ||
if (!has_elmt_) { | ||
assert(queue_.empty()); | ||
continue; | ||
} | ||
rv = queue_.front(); | ||
This comment has been minimized.
Sorry, something went wrong. |
||
queue_.pop_front(); | ||
if (queue_.empty()) | ||
has_elmt_ = false; | ||
return false; | ||
} | ||
} | ||
{ | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
while (queue_.empty() && !exit_now_) { | ||
cv_.wait(lock); | ||
} | ||
if (!exit_now_) { | ||
rv = queue_.front(); | ||
queue_.pop_front(); | ||
if (queue_.empty()) | ||
has_elmt_ = false; | ||
return false; | ||
} else { | ||
return true; | ||
} | ||
} | ||
} | ||
std::list<T> PopAll() { | ||
std::lock_guard<std::mutex> lock(mutex_); | ||
std::list<T> rv; | ||
rv.swap(queue_); | ||
return rv; | ||
} | ||
// Call `SignalForKill` before destruction | ||
void SignalForKill() { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
exit_now_ = true; | ||
cv_.notify_all(); | ||
} | ||
size_t QueueSize() { | ||
std::unique_lock<std::mutex> lock(mutex_); | ||
return queue_.size(); | ||
} | ||
|
||
private: | ||
std::atomic<bool> has_elmt_; | ||
std::list<T> queue_; | ||
std::mutex mutex_; | ||
std::condition_variable cv_; | ||
std::atomic<bool> exit_now_; | ||
|
||
ConcurrentBlockingQueue(const ConcurrentBlockingQueue&) = delete; | ||
ConcurrentBlockingQueue& operator=(const ConcurrentBlockingQueue&) = delete; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#ifndef _SPINLOCK_XCHG_H | ||
#define _SPINLOCK_XCHG_H | ||
|
||
/* Spin lock using xchg. | ||
* Copied from http://locklessinc.com/articles/locks/ | ||
*/ | ||
|
||
/* Compile read-write barrier */ | ||
#define barrier() asm volatile("": : :"memory") | ||
|
||
/* Pause instruction to prevent excess processor bus usage */ | ||
#define cpu_relax() asm volatile("pause\n": : :"memory") | ||
|
||
static inline unsigned short xchg_8(void *ptr, unsigned char x) { | ||
__asm__ __volatile__("xchgb %0,%1" | ||
:"=r" (x) | ||
:"m" (*(volatile unsigned char *)ptr), "0" (x) | ||
:"memory"); | ||
|
||
return x; | ||
} | ||
|
||
#define BUSY 1 | ||
typedef unsigned char spinlock; | ||
|
||
#define SPINLOCK_INITIALIZER 0 | ||
|
||
static inline void spin_lock(spinlock *lock) { | ||
while (1) { | ||
if (!xchg_8(lock, BUSY)) return; | ||
|
||
while (*lock) cpu_relax(); | ||
} | ||
} | ||
|
||
static inline void spin_unlock(spinlock *lock) { | ||
barrier(); | ||
*lock = 0; | ||
} | ||
|
||
static inline int spin_trylock(spinlock *lock) { | ||
return xchg_8(lock, BUSY); | ||
} | ||
|
||
#endif /* _SPINLOCK_XCHG_H */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
#include <queue> | ||
#include <memory> | ||
#include <tuple> | ||
#include <utility> | ||
#include <atomic> | ||
#include <thread> | ||
#include <random> | ||
|
||
#include <dmlc/logging.h> | ||
#include <mxnet/dag_engine.h> | ||
#include "../common/spin_lock.h" | ||
#include "../common/concurrent_blocking_queue.h" | ||
|
||
using namespace std; | ||
|
||
namespace mxnet { | ||
|
||
#define DEFAULT_NUM_WORKER_THREADS 4 | ||
|
||
class ThreadedEngine : public DAGEngine { | ||
public: | ||
ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) { | ||
for(int i = 0; i < numthreads; ++i) { | ||
worker_queues_.push_back(new ConcurrentBlockingQueue<OpDescr*>()); | ||
workers_.emplace_back(&ThreadedEngine::WorkerRoutine, this, i); | ||
} | ||
} | ||
~ThreadedEngine() { | ||
for(int i = 0; i < numthreads_; ++i) { | ||
worker_queues_[i]->SignalForKill(); | ||
delete worker_queues_[i]; | ||
workers_[i].join(); | ||
} | ||
} | ||
void Push(AsyncOp exec_fun, | ||
Context exec_ctx, | ||
const vector<Variable> &use_vars, | ||
const vector<Variable> &mutate_vars) override { | ||
shared_ptr<OpDescr> opd( new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars}, | ||
[this] (OpDescr* o) { this->OnDepsResolved(o); } ); | ||
for( Variable v : use_vars ) { // read | ||
VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here | ||
spin_lock(&vard->lock); | ||
if (vard->rw < 0) { | ||
vard->waitings.push(make_pair(opd, DepType::kRead)); | ||
} else { | ||
++vard->rw; | ||
} | ||
spin_unlock(&vard->lock); | ||
} | ||
for( Variable v : mutate_vars ) { // write | ||
VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here | ||
spin_lock(&vard->lock); | ||
if (vard->rw != 0) { | ||
vard->waitings.push(make_pair(opd, DepType::kWrite)); | ||
} else { | ||
vard->rw = -1; | ||
} | ||
spin_unlock(&vard->lock); | ||
} | ||
} | ||
void Push(Op exec_fun, | ||
Context exec_ctx, | ||
const vector<Variable> &use_vars, | ||
const vector<Variable> &mutate_vars) override { | ||
this->Push([exec_fun](RunContext ctx, Callback on_complete) { | ||
exec_fun(ctx); on_complete(); | ||
}, exec_ctx, use_vars, mutate_vars); | ||
} | ||
void PushDelete(Op delete_fun, Variable var) override { | ||
// TODO | ||
this->Push([delete_fun, var] (RunContext ctx) { | ||
delete_fun(ctx); | ||
delete static_cast<VarDescr*>(var); | ||
}, Context()/* TODO exec_ctx is missing?*/, {}, {var}); | ||
} | ||
Variable NewVar() override { | ||
// in practice return a ptr to a cell | ||
// that have the info about the variable | ||
// use ptr directly instead of ID because this avoids an indirect mapping | ||
VarDescr* vd = new VarDescr; | ||
vd->lock = SPINLOCK_INITIALIZER; | ||
vd->rw = 0; | ||
return vd; | ||
} | ||
void WaitForVar(Variable var) override { | ||
// TODO | ||
} | ||
void WaitForAll() override { | ||
// TODO | ||
} | ||
private: | ||
enum class DepType { | ||
kRead = 0, | ||
kWrite, | ||
kDelete, | ||
}; | ||
struct OpDescr { | ||
AsyncOp op; | ||
Context exec_ctx; | ||
vector<Variable> read_vars; | ||
vector<Variable> write_vars; | ||
}; | ||
struct VarDescr { | ||
spinlock lock; | ||
int rw; // a semaphore-like count | ||
// if rw > 0, the variable has several readers and the number | ||
// means how many operators are currently reading it; | ||
// if rw < 0, the varaible has one writer (should be -1) | ||
queue<pair<shared_ptr<OpDescr>, DepType>> waitings; | ||
}; | ||
void TriggerWaiting(VarDescr* vard) { | ||
// ATTENTION: this function should be called with vard->lock held. | ||
CHECK(vard->rw == 0) << "the variable should be free during triggering"; | ||
if(!vard->waitings.empty()) { | ||
// pop all reads first | ||
while(vard->waitings.front().second == DepType::kRead) { | ||
vard->waitings.pop(); | ||
++vard->rw; | ||
} | ||
if (vard->rw == 0) { | ||
// if the next one is a delete | ||
// pop the next write | ||
vard->waitings.pop(); | ||
vard->rw = -1; | ||
} | ||
} | ||
} | ||
void OnOpFinished(OpDescr* opd) { | ||
CHECK(opd) << "completing a nullptr op!"; | ||
for(Variable v : opd->read_vars) { | ||
VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here | ||
spin_lock(&vard->lock); | ||
CHECK(vard->rw > 0) << "incorrect rw count (reader):" << vard->rw; | ||
if(--vard->rw == 0) { | ||
TriggerWaiting(vard); | ||
} | ||
spin_unlock(&vard->lock); | ||
} | ||
for(Variable v : opd->write_vars) { | ||
VarDescr* vard = static_cast<VarDescr*>(v); // safe to cast here | ||
spin_lock(&vard->lock); | ||
CHECK(vard->rw == -1) << "incorrect rw count (writer):" << vard->rw; | ||
vard->rw = 0; | ||
TriggerWaiting(vard); | ||
spin_unlock(&vard->lock); | ||
} | ||
delete opd; // delete the operator | ||
} | ||
RunContext GetRunContext(const Context& ctx) { | ||
// TODO | ||
return RunContext(); | ||
} | ||
void OnDepsResolved(OpDescr* opd) { | ||
static default_random_engine generator; | ||
static uniform_int_distribution<int> distribution(0, numthreads_); | ||
int thrid = distribution(generator); | ||
worker_queues_[thrid]->Push(opd); | ||
} | ||
void WorkerRoutine(int thrid) { | ||
OpDescr* opd = nullptr; | ||
while(! worker_queues_[thrid]->Pop(opd)) { | ||
LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; | ||
opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); }); | ||
opd = nullptr; | ||
} | ||
} | ||
private: | ||
const int numthreads_; | ||
vector<ConcurrentBlockingQueue<OpDescr*>*> worker_queues_; | ||
vector<thread> workers_; | ||
}; | ||
|
||
// implements the singleton factory | ||
DAGEngine* DAGEngine::Get() { | ||
static ThreadedEngine engine; | ||
return &engine; | ||
} | ||
} // namespace mxnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#include <mxnet/dag_engine.h> | ||
|
||
using namespace std; | ||
using namespace mxnet; | ||
|
||
int main() { | ||
DAGEngine* engine = DAGEngine::Get(); | ||
return 0; | ||
} |
rv = std::move(queue_.front()) might be more efficient