Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Backport #17007 to 1.6 #17052

Merged
merged 1 commit into from
Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions src/kvstore/kvstore_dist_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,21 +364,34 @@ class KVStoreDistServer {
if (log_verbose_) {
LOG(INFO) << "sent response to " << update_buf->request.size() << " workers";
}
/**
* Request can be for either push, pull or pushpull
* If pull flag is set, respond immediately with the updated values
* Otherwise, only send the notification
*/
bool has_pull = false;
for (const auto& req : update_buf->request) {
/**
* Request can be for either push, pull or pushpull
* If pull flag is set, respond immediately with the updated values
* Otherwise, only send the notification
*/
if (req.pull) {
DefaultStorageResponse(type, key, req, req_data, server);
} else {
has_pull = has_pull || req.pull;
}
if (has_pull) {
// if there is a pull request, perform WaitToRead() once before DefaultStorageResponse
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
stored.WaitToRead();
for (const auto& req : update_buf->request) {
if (req.pull) {
DefaultStorageResponse(type, key, req, req_data, server);
}
}
update_buf->request.clear();
} else {
// otherwise, send response directly
for (const auto& req : update_buf->request) {
server->Response(req);
}
update_buf->request.clear();
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
stored.WaitToRead();
}
update_buf->request.clear();
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
stored.WaitToRead();
} else {
update_buf->merged.WaitToRead();
}
Expand Down
35 changes: 19 additions & 16 deletions tests/nightly/dist_device_sync_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def check_diff_to_scalar(A, x, rank=None):
def init_kv():
# init kv dns keys
kv.init(keys, [mx.nd.ones(shape)] * len(keys))
kv.init('9', mx.nd.ones(shape))
kv.init('10', mx.nd.ones(shape))
kv.init('99', mx.nd.ones(big_shape))
kv.init('100', mx.nd.ones(big_shape))
# worker info
my_rank = kv.rank
nworker = kv.num_workers
Expand All @@ -55,33 +58,30 @@ def init_kv():
def test_sync_push_pull():
kv, my_rank, nworker = init_kv()
num_gpus = 2
def check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False):
def check_default_keys(kv, my_rank, nworker, nrepeat=3):
# checks pull after push in loop, because behavior during
# consecutive pushes doesn't offer any guarantees
for i in range(offset, nrepeat):
for i in range(nrepeat):
scale = my_rank + 1
num = (nworker + 1) * nworker * rate * num_gpus / 2 * (i + 1) + 1

arr = [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
val = mx.nd.zeros(shape)
if use_pushpull:
kv.pushpull('3', arr, out=val)
else:
kv.push('3', arr)
kv.pull('3', out=val)
kv.push('9', arr)
kv.pull('9', out=val)
check_diff_to_scalar(val, num)
kv.pushpull('10', arr, out=val)
check_diff_to_scalar(val, num)

big_arr = [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
big_val = mx.nd.zeros(big_shape)
if use_pushpull:
kv.pushpull('99', big_arr, out=big_val)
else:
kv.push('99', big_arr)
kv.pull('99', out=big_val)
kv.push('99', big_arr)
kv.pull('99', out=big_val)
check_diff_to_scalar(big_val, num)
kv.pushpull('100', big_arr, out=big_val)
check_diff_to_scalar(big_val, num)

check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False)
check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=3, use_pushpull=True)
check_default_keys(kv, my_rank, nworker, nrepeat=3)
print('worker ' + str(my_rank) + ' is done')

def test_sync_init():
Expand All @@ -106,10 +106,12 @@ def check_trainer_kv_update(update_on_kv):
x = params.get('x', shape=(10,1), lr_mult=1.0)
params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
try:
trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv)
trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1},
kvstore=kv, update_on_kvstore=update_on_kv)
trainer._init_kvstore()
assert trainer._kv_initialized
assert trainer._update_on_kvstore is True
if update_on_kv is not None:
assert trainer._update_on_kvstore is update_on_kv
except ValueError:
assert update_on_kv is False

Expand All @@ -122,3 +124,4 @@ def check_trainer_kv_update(update_on_kv):
if __name__ == "__main__":
test_sync_init()
test_sync_push_pull()
test_gluon_trainer_type()