From 9dfa1d8e2e32a0543322c26b287ef06d385f3520 Mon Sep 17 00:00:00 2001 From: Yimin Jiang Date: Sun, 22 Sep 2019 08:02:41 +0800 Subject: [PATCH] add async training support (#2) --- src/kvstore/kvstore_dist_server.h | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index daf62840ccc5..801fc4d4e64b 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -175,10 +175,14 @@ class KVStoreDistServer { std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2)); ps_server_->set_request_handle( std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3)); - sync_mode_ = dmlc::GetEnv("IS_PS_SYNC_MODE", true); gradient_compression_ = std::make_shared(); log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); update_buf_wait_ = dmlc::GetEnv("PS_ENABLE_GRADIENT_WAIT", false); + + sync_mode_ = !dmlc::GetEnv("BYTEPS_ENABLE_ASYNC", false); + if (!sync_mode_) { + LOG(INFO) << "BytePS server is enabled asynchronous training"; + } } ~KVStoreDistServer() { @@ -218,7 +222,7 @@ class KVStoreDistServer { exec_.Stop(); break; case CommandType::kSyncMode: - sync_mode_ = true; + CHECK(0) << "kSyncMode is not available now"; break; case CommandType::kSetGradientCompression: gradient_compression_->DecodeParams(recved.body); @@ -365,7 +369,10 @@ class KVStoreDistServer { auto& update = sync_mode_ ? update_buf->merged : update_buf->temp_array; // NOTE: not sure whether we need this WaitToRead, default is disabled if (update_buf_wait_) update_buf->merged.WaitToRead(); - CopyFromTo(update_buf->merged, &stored); + + // async mode does not need this Copy + if (sync_mode_) CopyFromTo(update_buf->merged, &stored); + if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]); update_buf->request.clear(); stored.WaitToRead(); @@ -782,7 +789,7 @@ class KVStoreDistServer { if (has_multi_precision_copy(type)) { CopyFromTo(recved, updates.temp_array); } else { - updates.temp_array = recved; + stored += recved; } } } else { // from other workers @@ -861,6 +868,7 @@ class KVStoreDistServer { bool multi_precision_; bool update_buf_wait_; + /* * send push response with the key as value */