Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix failing GPU test on single GPU host (kvstore) (#12726)
Browse files Browse the repository at this point in the history
Fixes #10977
  • Loading branch information
larroy authored and gigasquid committed Oct 6, 2018
1 parent c993ef1 commit 836ba78
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions tests/python/gpu/test_kvstore_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,16 @@ def init_kv_with_str(stype='default', kv_type='local'):
# Test seed 89411477 (module seed 1829754103) resulted in a py3-gpu CI runner core dump.
# Not reproducible, so this test is back on random seeds.
@with_seed()
@unittest.skipIf(mx.context.num_gpus() < 2, "test_rsp_push_pull needs more than 1 GPU")
def test_rsp_push_pull():
def check_rsp_push_pull(kv_type, sparse_pull, is_push_cpu=True):
kv = init_kv_with_str('row_sparse', kv_type)
kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
push_ctxs = [mx.cpu(i) if is_push_cpu else mx.gpu(i) for i in range(2)]
kv.push('e', [mx.nd.ones(shape, ctx=context).tostype('row_sparse') for context in push_ctxs])

def check_rsp_pull(kv, count, ctxs, sparse_pull, is_same_rowid=False, use_slice=False):
def check_rsp_pull(kv, ctxs, sparse_pull, is_same_rowid=False, use_slice=False):
count = len(ctxs)
num_rows = shape[0]
row_ids = []
all_row_ids = np.arange(num_rows)
Expand Down Expand Up @@ -100,14 +102,14 @@ def check_rsp_pull(kv, count, ctxs, sparse_pull, is_same_rowid=False, use_slice=
expected_val[:] = 2
assert_almost_equal(retained, expected_val)

check_rsp_pull(kv, 1, [mx.gpu(0)], sparse_pull)
check_rsp_pull(kv, 1, [mx.cpu(0)], sparse_pull)
check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], sparse_pull)
check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], sparse_pull, is_same_rowid=True)
check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], sparse_pull)
check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], sparse_pull, is_same_rowid=True)
check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True)
check_rsp_pull(kv, 4, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True)
check_rsp_pull(kv, [mx.gpu(0)], sparse_pull)
check_rsp_pull(kv, [mx.cpu(0)], sparse_pull)
check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull)
check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, is_same_rowid=True)
check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull)
check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, is_same_rowid=True)
check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True)
check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True)

envs = ["","1"]
key = "MXNET_KVSTORE_USETREE"
Expand Down

0 comments on commit 836ba78

Please sign in to comment.