Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add PushPull test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
Anand J committed Sep 2, 2019
1 parent 3cc6dfb commit 2eb6321
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions tests/nightly/dist_device_sync_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,33 @@ def init_kv():
def test_sync_push_pull():
kv, my_rank, nworker = init_kv()
num_gpus = 2
def check_default_keys(kv, my_rank, nworker):
nrepeat = 3
def check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False):
# checks pull after push in loop, because behavior during
# consecutive pushes doesn't offer any guarantees
for i in range(nrepeat):
for i in range(offset, nrepeat):
scale = my_rank + 1
kv.push('3', [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)])
kv.push('99', [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)])
num = (nworker + 1) * nworker * rate * num_gpus / 2 * (i + 1) + 1

arr = [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
val = mx.nd.zeros(shape)
kv.pull('3', out=val)
if use_pushpull:
kv.pushpull('3', arr, out=val)
else:
kv.push('3', arr)
kv.pull('3', out=val)
check_diff_to_scalar(val, num)
val2 = mx.nd.zeros(big_shape)
kv.pull('99', out=val2)
check_diff_to_scalar(val2, num)

check_default_keys(kv, my_rank, nworker)
big_arr = [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
big_val = mx.nd.zeros(big_shape)
if use_pushpull:
kv.pushpull('99', big_arr, out=big_val)
else:
kv.push('99', big_arr)
kv.pull('99', out=big_val)
check_diff_to_scalar(big_val, num)

check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False)
check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=3, use_pushpull=True)
print('worker ' + str(my_rank) + ' is done')

def test_sync_init():
Expand Down

0 comments on commit 2eb6321

Please sign in to comment.