From 532d62fff15d339d765717d2d4671cd771e308cc Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 8 Jan 2020 18:14:58 +0000 Subject: [PATCH] adding asnumpy() to output of gather(implicitly called) to fix gather test in large vector and tensor tests --- tests/nightly/test_large_array.py | 4 ++-- tests/nightly/test_large_vector.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 2780b0d45d62..d2836f1807b1 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1672,10 +1672,10 @@ def test_gather(): idx = mx.nd.random.randint(0, LARGE_X, SMALL_X) # Calls gather_nd internally tmp = arr[idx] - assert np.sum(tmp[0] == 1) == SMALL_Y + assert np.sum(tmp[0].asnumpy() == 1) == SMALL_Y # Calls gather_nd internally arr[idx] += 1 - assert np.sum(arr[idx[0]] == 2) == SMALL_Y + assert np.sum(arr[idx[0]].asnumpy() == 2) == SMALL_Y if __name__ == '__main__': diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index c6a99a5d0826..bc0dedf269bf 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -1049,10 +1049,10 @@ def test_gather(): idx = mx.nd.random.randint(0, LARGE_X, 10, dtype=np.int64) # Calls gather_nd internally tmp = arr[idx] - assert np.sum(tmp == 1) == 10 + assert np.sum(tmp.asnumpy() == 1) == 10 # Calls gather_nd internally arr[idx] += 1 - assert np.sum(arr[idx] == 2) == 10 + assert np.sum(arr[idx].asnumpy() == 2) == 10 def test_infer_shape():