From 8d4ac0e88df43a3eefdd0b86fe5f5f53df48f36a Mon Sep 17 00:00:00 2001 From: Anand J Date: Mon, 26 Aug 2019 13:39:45 -0700 Subject: [PATCH] Add PushPull test cases --- tests/nightly/dist_device_sync_kvstore.py | 30 +++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/nightly/dist_device_sync_kvstore.py b/tests/nightly/dist_device_sync_kvstore.py index 7fd0333aea79..dc2c7bc35747 100644 --- a/tests/nightly/dist_device_sync_kvstore.py +++ b/tests/nightly/dist_device_sync_kvstore.py @@ -55,23 +55,33 @@ def init_kv(): def test_sync_push_pull(): kv, my_rank, nworker = init_kv() num_gpus = 2 - def check_default_keys(kv, my_rank, nworker): - nrepeat = 3 + def check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False): # checks pull after push in loop, because behavior during # consecutive pushes doesn't offer any guarantees - for i in range(nrepeat): + for i in range(offset, nrepeat): scale = my_rank + 1 - kv.push('3', [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]) - kv.push('99', [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]) num = (nworker + 1) * nworker * rate * num_gpus / 2 * (i + 1) + 1 + + arr = [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)] val = mx.nd.zeros(shape) - kv.pull('3', out=val) + if use_pushpull: + kv.pushpull('3', arr, out=val) + else: + kv.push('3', arr) + kv.pull('3', out=val) check_diff_to_scalar(val, num) - val2 = mx.nd.zeros(big_shape) - kv.pull('99', out=val2) - check_diff_to_scalar(val2, num) - check_default_keys(kv, my_rank, nworker) + big_arr = [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)] + big_val = mx.nd.zeros(big_shape) + if use_pushpull: + kv.pushpull('99', big_arr, out=big_val) + else: + kv.push('99', big_arr) + kv.pull('99', out=big_val) + check_diff_to_scalar(big_val, num) + + check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False) + check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=3, use_pushpull=True) print('worker ' + str(my_rank) + ' is done') def test_sync_init():