diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 65ded79743e4..1dc222c0d7da 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -364,21 +364,34 @@ class KVStoreDistServer { if (log_verbose_) { LOG(INFO) << "sent response to " << update_buf->request.size() << " workers"; } + /** + * Request can be for either push, pull or pushpull + * If pull flag is set, respond immediately with the updated values + * Otherwise, only send the notification + */ + bool has_pull = false; for (const auto& req : update_buf->request) { - /** - * Request can be for either push, pull or pushpull - * If pull flag is set, respond immediately with the updated values - * Otherwise, only send the notification - */ - if (req.pull) { - DefaultStorageResponse(type, key, req, req_data, server); - } else { + has_pull = has_pull || req.pull; + } + if (has_pull) { + // if there is a pull request, perform WaitToRead() once before DefaultStorageResponse + if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]); + stored.WaitToRead(); + for (const auto& req : update_buf->request) { + if (req.pull) { + DefaultStorageResponse(type, key, req, req_data, server); + } + } + update_buf->request.clear(); + } else { + // otherwise, send response directly + for (const auto& req : update_buf->request) { server->Response(req); } + update_buf->request.clear(); + if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]); + stored.WaitToRead(); } - update_buf->request.clear(); - if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]); - stored.WaitToRead(); } else { update_buf->merged.WaitToRead(); } diff --git a/tests/nightly/dist_device_sync_kvstore.py b/tests/nightly/dist_device_sync_kvstore.py index dc2c7bc35747..f3fe737f5653 100644 --- a/tests/nightly/dist_device_sync_kvstore.py +++ b/tests/nightly/dist_device_sync_kvstore.py @@ -44,7 +44,10 @@ def check_diff_to_scalar(A, x, rank=None): def init_kv(): # init kv dns keys kv.init(keys, [mx.nd.ones(shape)] * len(keys)) + kv.init('9', mx.nd.ones(shape)) + kv.init('10', mx.nd.ones(shape)) kv.init('99', mx.nd.ones(big_shape)) + kv.init('100', mx.nd.ones(big_shape)) # worker info my_rank = kv.rank nworker = kv.num_workers @@ -55,33 +58,30 @@ def init_kv(): def test_sync_push_pull(): kv, my_rank, nworker = init_kv() num_gpus = 2 - def check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False): + def check_default_keys(kv, my_rank, nworker, nrepeat=3): # checks pull after push in loop, because behavior during # consecutive pushes doesn't offer any guarantees - for i in range(offset, nrepeat): + for i in range(nrepeat): scale = my_rank + 1 num = (nworker + 1) * nworker * rate * num_gpus / 2 * (i + 1) + 1 arr = [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)] val = mx.nd.zeros(shape) - if use_pushpull: - kv.pushpull('3', arr, out=val) - else: - kv.push('3', arr) - kv.pull('3', out=val) + kv.push('9', arr) + kv.pull('9', out=val) + check_diff_to_scalar(val, num) + kv.pushpull('10', arr, out=val) check_diff_to_scalar(val, num) big_arr = [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)] big_val = mx.nd.zeros(big_shape) - if use_pushpull: - kv.pushpull('99', big_arr, out=big_val) - else: - kv.push('99', big_arr) - kv.pull('99', out=big_val) + kv.push('99', big_arr) + kv.pull('99', out=big_val) + check_diff_to_scalar(big_val, num) + kv.pushpull('100', big_arr, out=big_val) check_diff_to_scalar(big_val, num) - check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False) - check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=3, use_pushpull=True) + check_default_keys(kv, my_rank, nworker, nrepeat=3) print('worker ' + str(my_rank) + ' is done') def test_sync_init(): @@ -106,10 +106,12 @@ def check_trainer_kv_update(update_on_kv): x = params.get('x', shape=(10,1), lr_mult=1.0) params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') try: - trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv) + trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, + kvstore=kv, update_on_kvstore=update_on_kv) trainer._init_kvstore() assert trainer._kv_initialized - assert trainer._update_on_kvstore is True + if update_on_kv is not None: + assert trainer._update_on_kvstore is update_on_kv except ValueError: assert update_on_kv is False @@ -122,3 +124,4 @@ def check_trainer_kv_update(update_on_kv): if __name__ == "__main__": test_sync_init() test_sync_push_pull() + test_gluon_trainer_type()