Skip to content

Commit

Permalink
common: delay ps initialization for better RDMA compatibility (#91)
Browse files Browse the repository at this point in the history
* delay ps initialization for better RDMA compatibility

* common: quick fix BytePSGlobal::GetOrInitPS()

* common: move ps init logic

* 3rdparty: fix rdma runtime bug
  • Loading branch information
bobzhuyb authored and ymjiang committed Sep 23, 2019
1 parent b9a5ba6 commit f18f1a8
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 16 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/ps-lite
Submodule ps-lite updated 3 files
+4 −4 Makefile
+6 −1 src/customer.cc
+53 −96 src/rdma_van.h
29 changes: 17 additions & 12 deletions byteps/common/global.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,6 @@ void BytePSGlobal::Init() {

_shm_obj = std::make_shared<BytePSSharedMemory>(); // share memory obj

if (IsDistributed() &&
_my_role ==
BytePSRole::LOCAL_ROOT) { // only the root need to do networking
// init low-level ps implementation
_ps = new ps::KVWorker<char>(0, 0);
ps::StartAsync(0, "byteps\0");
if (!ps::Postoffice::Get()->is_recovery()) {
ps::Postoffice::Get()->Barrier(
0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
}
}

// Set to associated GPU
CUDA_CALL(cudaSetDevice(_local_rank));

Expand Down Expand Up @@ -200,6 +188,23 @@ void BytePSGlobal::Init() {
return;
}

ps::KVWorker<char>* BytePSGlobal::GetOrInitPS() {
  // Lazily create the ps-lite worker. _init_mutex is reused here because
  // BytePS itself must already have been initialized by this point.
  std::lock_guard<std::mutex> lock(_init_mutex);

  // Nothing to do when the worker already exists, when running
  // non-distributed, or when this process is not the local root
  // (only the root does any networking). _ps may be nullptr here.
  if (_ps || !IsDistributed() || _my_role != BytePSRole::LOCAL_ROOT) {
    return _ps;
  }

  // First call on the local root of a distributed run: bring up the
  // low-level ps implementation.
  _ps = new ps::KVWorker<char>(0, 0);
  ps::StartAsync(0, "byteps\0");
  if (!ps::Postoffice::Get()->is_recovery()) {
    // Synchronize with every other node before any push/pull traffic.
    ps::Postoffice::Get()->Barrier(
        0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
  }
  return _ps;
}

void BytePSGlobal::Start(const std::vector<LoopFunction>& func) {
// Start background threads
for (size_t i = 0; i < func.size(); i++) {
Expand Down
1 change: 1 addition & 0 deletions byteps/common/global.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class BytePSGlobal {
static BytePSScheduledQueue* GetScheduledQueue(QueueType queueType);
static void CreateScheduledQueue(QueueType queueType);
// Returns the cached ps-lite worker without locking; may be nullptr if the
// worker has not been created yet (use GetOrInitPS() for lazy creation).
static ps::KVWorker<char>* GetPS() { return _ps; }
// Thread-safe accessor that lazily creates the ps-lite worker on first use;
// returns nullptr on non-distributed runs or non-root processes.
static ps::KVWorker<char>* GetOrInitPS();

static bool IsTensorDeclared(const std::string& name);
static ps::Key GetKeyFromName(const std::string& name);
Expand Down
6 changes: 3 additions & 3 deletions byteps/common/operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,15 @@ void InitTensor(BPSContext &context, size_t size, int dtype, void *cpubuff) {
int len = ((size - accumulated) > bound) ? bound : (size - accumulated);

if (BytePSGlobal::IsDistributed() && BytePSGlobal::IsRootDevice()) {
auto ps = BytePSGlobal::GetOrInitPS();
// encode the key for pskv scattering
auto &pskv = BytePSGlobal::EncodeDefaultKey(key, len);
// false means not to delete data when SArray is deleted
ps::SArray<char> vals(data + accumulated, len, false);
// cmd type
int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
// blocking push, also as a global barrier
BytePSGlobal::GetPS()->Wait(
BytePSGlobal::GetPS()->ZPush(pskv.keys, vals, pskv.lens, cmd));
// blocking push, also as a global barrier
ps->Wait(ps->ZPush(pskv.keys, vals, pskv.lens, cmd));
}

accumulated += len;
Expand Down

0 comments on commit f18f1a8

Please sign in to comment.