Skip to content

Commit

Permalink
add async training support (apache#2)
Browse files: browse the repository at this point in the history
  • Loading branch information
ymjiang committed Sep 22, 2019
1 parent adadd5d commit 9dfa1d8
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/kvstore/kvstore_dist_server.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -175,10 +175,14 @@ class KVStoreDistServer {
std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
ps_server_->set_request_handle(
std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));
sync_mode_ = dmlc::GetEnv("IS_PS_SYNC_MODE", true);
gradient_compression_ = std::make_shared<GradientCompression>();
log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
update_buf_wait_ = dmlc::GetEnv("PS_ENABLE_GRADIENT_WAIT", false);

sync_mode_ = !dmlc::GetEnv("BYTEPS_ENABLE_ASYNC", false);
if (!sync_mode_) {
LOG(INFO) << "BytePS server is enabled asynchronous training";
}
}

~KVStoreDistServer() {
Expand Down Expand Up @@ -218,7 +222,7 @@ class KVStoreDistServer {
exec_.Stop();
break;
case CommandType::kSyncMode:
sync_mode_ = true;
CHECK(0) << "kSyncMode is not available now";
break;
case CommandType::kSetGradientCompression:
gradient_compression_->DecodeParams(recved.body);
Expand Down Expand Up @@ -365,7 +369,10 @@ class KVStoreDistServer {
auto& update = sync_mode_ ? update_buf->merged : update_buf->temp_array;
// NOTE: not sure whether we need this WaitToRead, default is disabled
if (update_buf_wait_) update_buf->merged.WaitToRead();
CopyFromTo(update_buf->merged, &stored);

// async mode does not need this Copy
if (sync_mode_) CopyFromTo(update_buf->merged, &stored);

if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
update_buf->request.clear();
stored.WaitToRead();
Expand Down Expand Up @@ -782,7 +789,7 @@ class KVStoreDistServer {
if (has_multi_precision_copy(type)) {
CopyFromTo(recved, updates.temp_array);
} else {
updates.temp_array = recved;
stored += recved;
}
}
} else { // from other workers
Expand Down Expand Up @@ -861,6 +868,7 @@ class KVStoreDistServer {
bool multi_precision_;

bool update_buf_wait_;

/*
* send push response with the key as value
*/
Expand Down

0 comments on commit 9dfa1d8

Please sign in to comment.