From 1eb499556a54498e81ebc614ae54ba74e406d24a Mon Sep 17 00:00:00 2001 From: sguangyo <1360024032@qq.com> Date: Mon, 2 Sep 2019 12:38:37 +0800 Subject: [PATCH] fix test_pick test time is too long --- tests/python/unittest/test_operator.py | 69 +++++++++++++------------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 39ae0a02f0fe..de02727e24c8 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5260,42 +5260,41 @@ def np_softmax_with_length(data, length): @with_seed() def test_pick(): def test_pick_helper(index_type=np.int32): - for _ in range(100): - for mode in ['clip', 'wrap']: - ndim = np.random.randint(1, 5) - bshape = np.random.randint(1, 10, size=ndim) - axis = np.random.randint(0, ndim) - sshape = bshape.copy() - sshape[axis] = 1 - data = np.random.uniform(-1, 1, size=bshape) - - if mode == 'wrap': - index = np.random.randint(-2*bshape[axis], 2*bshape[axis], size=sshape) - else: - index = np.random.randint(0, bshape[axis], size=sshape) - exp = [] - for i in range(ndim): - if i == axis: - if mode == 'wrap': - exp.append(index % bshape[axis]) - else: - exp.append(index) + for mode in ['clip', 'wrap']: + ndim = np.random.randint(1, 5) + bshape = np.random.randint(1, 10, size=ndim) + axis = np.random.randint(0, ndim) + sshape = bshape.copy() + sshape[axis] = 1 + data = np.random.uniform(-1, 1, size=bshape) + + if mode == 'wrap': + index = np.random.randint(-2*bshape[axis], 2*bshape[axis], size=sshape) + else: + index = np.random.randint(0, bshape[axis], size=sshape) + exp = [] + for i in range(ndim): + if i == axis: + if mode == 'wrap': + exp.append(index % bshape[axis]) else: - ishape = [1 for _ in range(ndim)] - ishape[i] = bshape[i] - exp.append(np.arange(bshape[i]).reshape(ishape)) - expected = data[exp] - data = mx.nd.array(data, dtype='float32') - index = mx.nd.array(index, dtype=index_type) - out = mx.nd.pick(data, index, axis=axis, keepdims=True, mode=mode) - assert_almost_equal(out.asnumpy(), expected) - - data_holder = data - index_holder = index - data = mx.sym.Variable('data') - index = mx.sym.Variable('index') - sym = mx.sym.pick(data, index, axis=axis, keepdims=True, mode=mode) - check_numeric_gradient(sym, [data_holder, index_holder], grad_nodes=['data']) + exp.append(index) + else: + ishape = [1 for _ in range(ndim)] + ishape[i] = bshape[i] + exp.append(np.arange(bshape[i]).reshape(ishape)) + expected = data[exp] + data = mx.nd.array(data, dtype='float32') + index = mx.nd.array(index, dtype=index_type) + out = mx.nd.pick(data, index, axis=axis, keepdims=True, mode=mode) + assert_almost_equal(out.asnumpy(), expected) + + data_holder = data + index_holder = index + data = mx.sym.Variable('data') + index = mx.sym.Variable('index') + sym = mx.sym.pick(data, index, axis=axis, keepdims=True, mode=mode) + check_numeric_gradient(sym, [data_holder, index_holder], grad_nodes=['data']) test_pick_helper(np.int32) test_pick_helper(np.float32)