From 6697d13fbe161f170b8275fcf474da44ce01126d Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 14 Aug 2019 22:39:11 +0000 Subject: [PATCH 1/4] Adding tests to verify support for Large Tensors in additional Ops along with new C_Apis supporting 64bit indexing --- tests/nightly/test_large_vector.py | 318 +++++++++++++++++++++++++++++ 1 file changed, 318 insertions(+) diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 3a66500957e0..7ea27c221abf 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -25,6 +25,7 @@ # dimension constants LARGE_X = 5000000000 MEDIUM_X = 1000000000 +SMALL_Y = 1 def test_slice(): @@ -33,6 +34,323 @@ def test_slice(): assert res.shape[0] == MEDIUM_X +def test_gluon_embedding(): + m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X) + m.initialize() + a = nd.zeros((MEDIUM_X, SMALL_Y)) + b = m(a) + assert b.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X) + assert b.asnumpy().size == LARGE_SIZE + + +def test_ndarray_zeros(): + a = nd.zeros(shape=LARGE_X) + assert a[-1] == 0 + assert a.shape == (LARGE_X,) + assert a.size == LARGE_X + + +def test_ndarray_ones(): + a = nd.ones(shape=(LARGE_X)) + assert a[-1][0] == 1 + assert nd.sum(a).asnumpy() == LARGE_X + + +@unittest.skip("need to fix") +def test_ndarray_convert(): + a = nd.zeros(shape=(LARGE_X, SMALL_Y)) + b = a.astype(np.int32) + b.wait_to_read() + assert b.dtype == np.int32 + b = a.tostype('row_sparse') + b.wait_to_read() + assert isinstance(b, mx.nd.sparse.RowSparseNDArray) + + +@with_seed() +def test_ndarray_random_uniform(): + a = nd.random.uniform(shape=LARGE_X) + assert a[-1][0] != 0 + + +@with_seed() +def test_ndarray_random_randint(): + a = nd.random.randint(100, 10000, shape=LARGE_X) + assert a.shape == (LARGE_X,) + # check if randint can generate value greater than 2**32 (large) + low_large_value = 2**32 + high_large_value = 2**34 + a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64) + low = mx.nd.array([low_large_value], dtype='int64') + high = mx.nd.array([high_large_value], dtype='int64') + assert a.__gt__(low) and a.__lt__(high) + + +def test_ndarray_empty(): + a = nd.empty(LARGE_X) + assert a.shape == (LARGE_X,) + + +def test_elementwise(): + a = nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) + res = a + b + assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + res = a + 1 + assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + res = nd.sqrt(a + 3) + assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + + +def test_reduce(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1] + + +@unittest.skip("need to fix") +def test_dot(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + b = nd.ones(shape=(SMALL_Y, SMALL_Y)) + res = nd.dot(a, b) + assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1] + + +def test_FullyConnected(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + b = nd.ones(shape=(SMALL_Y, SMALL_Y)) + res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True) + assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1] + + +def test_broadcast(): + a = nd.ones(shape=(LARGE_X, SMALL_Y*2)) + b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) + res = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y*2)) + assert np.sum(res[-1].asnumpy() == LARGE_X) == res.shape[1] + res = mx.nd.broadcast_like(b, a) + assert np.sum(res[-1].asnumpy() == LARGE_X) == res.shape[1] + + +def test_clip(): + a = nd.arange(0, LARGE_X) + res = nd.clip(a, a_min=100, a_max=1000) + assert 
np.sum(res[-1].asnumpy() == 1000) == 101 + + +def test_argmin(): + a = nd.arange(0, LARGE_X) + idx = mx.nd.argmin(a, axis=0) + assert idx.shape[0] == SMALL_Y + + +def test_tile(): + a = nd.arange(0, LARGE_X) + b = nd.tile(a, reps=(1,2)) + assert b[0][LARGE_X] == b[0][0] + assert b[0][LARGE_X-1] == b[0][-1] + + +def test_take(): + a = nd.ones(shape=LARGE_X) + idx = nd.arange(LARGE_X - 1000, LARGE_X) + res = nd.take(a, idx) + assert np.sum(res.asnumpy() == 1) == res.shape[0] + + +def test_slice(): + a = nd.ones(shape=(2, LARGE_X)) + res = nd.slice(a, begin=(1, LARGE_X-1000000000), end=(2, LARGE_X)) + assert np.sum(res[-1].asnumpy() == 1) == res.shape[1] + + +def test_slice_assign(): + a = nd.ones(shape=LARGE_X) + a[LARGE_X-1:LARGE_X] = 1000 + assert np.sum(a[-1].asnumpy() == 1000) == 1 + + +def test_expand_dims(): + a = nd.ones(shape=LARGE_X) + res = nd.expand_dims(a, axis=0) + assert res[0][0] == 1 + assert res.shape == (1, a.shape[0]) + + +def test_squeeze(): + a = nd.ones(shape=LARGE_X) + data = nd.expand_dims(a, axis=0) + res = nd.squeeze(data) + assert a[0] == res[0] + assert res.shape == a.shape + + +def test_broadcast_div(): + a = nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) * 2 + res = a / b + assert np.sum(res.asnumpy() == 0.5) == a.shape[0] + + +def test_Dense(ctx=mx.cpu(0)): + data = mx.nd.ones(shape=LARGE_X) + linear = gluon.nn.Dense(2) + linear.initialize(ctx=ctx) + res = linear(data) + res.wait_to_read() + assert res.shape == (LARGE_X, 2) + + +@unittest.skip("need to fix") +def test_where(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + b = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y) + res = nd.where(b > 100, a, b) + assert np.sum(res[-1].asnumpy() == 1) == b.shape[1] + csr_cond = nd.sparse.cast_storage(b < 10, 'csr') + res = nd.sparse.where(csr_cond, a, b) + assert np.sum(res[0].asnumpy() == 1) == 10 + + +def test_pick(): + a = mx.nd.ones(shape=(LARGE_X, 2)) + b = mx.nd.ones(shape=LARGE_X) + res = mx.nd.pick(a, b) + assert res.shape == b.shape + + +def test_depthtospace(): + def numpy_depth_to_space(x, blocksize): + b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] + tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w]) + tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2]) + y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize]) + return y + + shape_inp = (LARGE_X, 4, 1, 1) + data = rand_ndarray(shape_inp, 'default') + data_np = data.asnumpy() + expected = numpy_depth_to_space(data_np, 2) + output = mx.nd.depth_to_space(data, 2) + assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) + + +def test_spacetodepth(): + def numpy_space_to_depth(x, blocksize): + b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] + tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize]) + tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4]) + y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize]) + return y + + shape_inp = (LARGE_X, 1, 2, 2) + data = rand_ndarray(shape_inp, 'default') + data_np = data.asnumpy() + expected = numpy_space_to_depth(data_np, 2) + output = mx.nd.space_to_depth(data, 2) + assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) + +@with_seed() +def test_diag(): + a_np = np.random.random((LARGE_X, 2)).astype(np.float32) + a = mx.nd.array(a_np) + + # k == 0 + r = mx.nd.diag(a) + assert_almost_equal(r.asnumpy(), np.diag(a_np)) + + # k == 1 + k = 1 + r = mx.nd.diag(a, k=k) + assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) + + # k == -1 + 
k = -1 + r = mx.nd.diag(a, k=k) + assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) + + +@with_seed() +def test_ravel_multi_index(): + x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, SMALL_Y, 4) + x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, SMALL_Y, 3) + x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, SMALL_Y, 2) + indices_2d = [[x1, x2, x3], [y1, y2, y3]] + idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, 5)) + idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, 5)) + assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3 + + +@with_seed() +def test_unravel_index(): + x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, SMALL_Y, 4) + x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, SMALL_Y, 3) + x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, SMALL_Y, 2) + original_2d_indices = [[x1, x2, x3], [y1, y2, y3]] + idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, 5)) + indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, 5)) + assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all() + + +def create_large_vector(size, dtype=np.int64): + a = nd.arange(0, size, dtype=dtype) + # Implicitly calling nd.waitall() + assert a[0] == 0 + return a + + +def test_transpose(): + b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X) + t = b.T + assert t.shape == (LARGE_X, 1) + assert t[-1, 0].asnumpy() == (LARGE_X - 1) + + +def test_swapaxes(): + b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(LARGE_X, 1) + t = nd.swapaxes(b, dim1=0, dim2=1) + assert t.shape == (1, LARGE_X) + assert t[0, -1].asnumpy() == (LARGE_X - 1) + + +def test_flip(): + b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X) + t = nd.flip(b, axis=0) + assert t.shape == (LARGE_X, 1) + assert t[-1, :].asnumpy() == 0 + + +def test_softmax(): + input_data = mx.nd.ones(2, LARGE_X) + true_output = np.full(LARGE_X, 0.5) + output = nd.softmax(input_data, axis=0) + assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5) + + +def test_argsort(): + b = create_large_vector(size=LARGE_X) + s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64) + mx.nd.waitall() + assert (s[0].asnumpy() == (LARGE_X - 1)).all() + + +def test_sort(): + b = create_large_vector(size=LARGE_X) + s = nd.sort(b, axis=0, is_ascend=False) + assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all() + s = nd.sort(b, is_ascend=True) + assert np.sum(s[0].asnumpy() == 0).all() + + +def test_topk(): + b = create_large_vector(size=LARGE_X) + k = nd.topk(b, k=10, axis=0, dtype=np.int64) + assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y + ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False) + assert np.all(ind == val) + l = nd.topk(b, k=1, axis=0, dtype=np.int64, ret_typ="value") + assert l.sum() == (LARGE_X - 1) + + if __name__ == '__main__': import nose nose.runmodule() From f568c3efd0fbdb5235ed074bfa04c044c29f9459 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 21 Aug 2019 02:47:53 +0000 Subject: [PATCH 2/4] removing skipped tests --- tests/nightly/test_large_vector.py | 45 ++++++------------------------ 1 file changed, 8 insertions(+), 37 deletions(-) diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 7ea27c221abf..b45d51d9f1fa 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -25,6 +25,7 @@ # dimension constants LARGE_X = 5000000000 MEDIUM_X = 1000000000 +LARGE_Y = 100000 SMALL_Y 
= 1 @@ -35,12 +36,12 @@ def test_slice(): def test_gluon_embedding(): - m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X) + m = gluon.nn.Embedding(1, LARGE_Y) m.initialize() - a = nd.zeros((MEDIUM_X, SMALL_Y)) + a = nd.zeros((LARGE_Y, 1)) b = m(a) - assert b.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X) - assert b.asnumpy().size == LARGE_SIZE + assert b.shape == (LARGE_Y, 1, LARGE_Y) + assert b.asnumpy().size == LARGE_X*2 def test_ndarray_zeros(): @@ -51,26 +52,15 @@ def test_ndarray_zeros(): def test_ndarray_ones(): - a = nd.ones(shape=(LARGE_X)) - assert a[-1][0] == 1 + a = nd.ones(shape=LARGE_X) + assert a[-1] == 1 assert nd.sum(a).asnumpy() == LARGE_X -@unittest.skip("need to fix") -def test_ndarray_convert(): - a = nd.zeros(shape=(LARGE_X, SMALL_Y)) - b = a.astype(np.int32) - b.wait_to_read() - assert b.dtype == np.int32 - b = a.tostype('row_sparse') - b.wait_to_read() - assert isinstance(b, mx.nd.sparse.RowSparseNDArray) - - @with_seed() def test_ndarray_random_uniform(): a = nd.random.uniform(shape=LARGE_X) - assert a[-1][0] != 0 + assert a[-1] != 0 @with_seed() @@ -107,14 +97,6 @@ def test_reduce(): assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1] -@unittest.skip("need to fix") -def test_dot(): - a = nd.ones(shape=(LARGE_X, SMALL_Y)) - b = nd.ones(shape=(SMALL_Y, SMALL_Y)) - res = nd.dot(a, b) - assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1] - - def test_FullyConnected(): a = nd.ones(shape=(LARGE_X, SMALL_Y)) b = nd.ones(shape=(SMALL_Y, SMALL_Y)) @@ -200,17 +182,6 @@ def test_Dense(ctx=mx.cpu(0)): assert res.shape == (LARGE_X, 2) -@unittest.skip("need to fix") -def test_where(): - a = nd.ones(shape=(LARGE_X, SMALL_Y)) - b = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y) - res = nd.where(b > 100, a, b) - assert np.sum(res[-1].asnumpy() == 1) == b.shape[1] - csr_cond = nd.sparse.cast_storage(b < 10, 'csr') - res = nd.sparse.where(csr_cond, a, b) - assert np.sum(res[0].asnumpy() == 1) == 10 - - def test_pick(): a = mx.nd.ones(shape=(LARGE_X, 2)) b = mx.nd.ones(shape=LARGE_X) From 12e77366a6b6a91abd4440d47b76e61d1ce1a63e Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Thu, 22 Aug 2019 06:49:37 +0000 Subject: [PATCH 3/4] enabling Large Index support for slice and softmax --- python/mxnet/test_utils.py | 13 ++++++++ src/operator/softmax_output-inl.h | 18 +++++------ src/operator/tensor/matrix_op-inl.h | 6 ++-- tests/nightly/test_large_array.py | 8 +---- tests/nightly/test_large_vector.py | 50 +++++++++-------------------- 5 files changed, 41 insertions(+), 54 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index c326091dbd9f..30d78d2e1593 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -262,6 +262,19 @@ def assign_each2(input1, input2, function): return output +# For testing Large Tensors having total size > 2^32 elements +def create_2d_tensor(rows, columns, dtype=np.int64): + a = nd.arange(0, rows, dtype=dtype).reshape(rows, 1) + b = nd.broadcast_to(a, shape=(a.shape[0], columns)) + return nd.array(b, dtype=dtype) + +# For testing Large Vectors having total size > 2^32 elements +def create_vector(size, dtype=np.int64): + a = nd.arange(0, size, dtype=dtype) + # Implicitly calling nd.waitall() + assert a[0] == 0 + return a + def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None, data_init=None, rsp_indices=None, modifier_func=None, shuffle_csr_indices=False, ctx=None): diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index 
80ab40ef6c50..db8676c028e4 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -117,9 +117,9 @@ class SoftmaxOutputOp : public Operator { CHECK_EQ(out_data.size(), 1U) << "SoftmaxOutput Output: [output]"; Stream *s = ctx.get_stream(); if (param_.multi_output) { - int n = in_data[softmaxout_enum::kData].size(0); - int k = in_data[softmaxout_enum::kData].size(1); - Shape<3> s3 = Shape3(n, k, static_cast(in_data[softmaxout_enum::kData].Size()/n/k)); + index_t n = in_data[softmaxout_enum::kData].size(0); + index_t k = in_data[softmaxout_enum::kData].size(1); + Shape<3> s3 = Shape3(n, k, static_cast(in_data[softmaxout_enum::kData].Size()/n/k)); Tensor data = in_data[softmaxout_enum::kData].get_with_shape(s3, s); Tensor out = @@ -131,8 +131,8 @@ class SoftmaxOutputOp : public Operator { Tensor out = out_data[softmaxout_enum::kOut].FlatTo2D(s); Softmax(out, data); } else { - int n = in_data[softmaxout_enum::kData].size(0); - int k = in_data[softmaxout_enum::kData].Size()/n; + index_t n = in_data[softmaxout_enum::kData].size(0); + index_t k = in_data[softmaxout_enum::kData].Size()/n; Shape<2> s2 = Shape2(n, k); Tensor data = in_data[softmaxout_enum::kData].get_with_shape(s2, s); @@ -171,9 +171,9 @@ class SoftmaxOutputOp : public Operator { grad = (out - label) * scalar(param_.grad_scale); } } else if (param_.multi_output) { - int n = out_data[softmaxout_enum::kOut].size(0); - int k = out_data[softmaxout_enum::kOut].size(1); - Shape<3> s3 = Shape3(n, k, static_cast(out_data[softmaxout_enum::kOut].Size()/n/k)); + index_t n = out_data[softmaxout_enum::kOut].size(0); + index_t k = out_data[softmaxout_enum::kOut].size(1); + Shape<3> s3 = Shape3(n, k, static_cast(out_data[softmaxout_enum::kOut].Size()/n/k)); Shape<2> s2 = Shape2(s3[0], s3[2]); Tensor label = in_data[softmaxout_enum::kLabel].get_with_shape(s2, s); @@ -224,7 +224,7 @@ class SoftmaxOutputOp : public Operator { // Tensor out = out_data[softmaxout_enum::kOut].FlatTo2D(s); // Tensor grad = in_grad[softmaxout_enum::kData].FlatTo2D(s); } else { - int n = out_data[softmaxout_enum::kOut].size(0); + index_t n = out_data[softmaxout_enum::kOut].size(0); data_shape = Shape2(n, out_data[softmaxout_enum::kOut].Size()/n); } Tensor label = in_data[softmaxout_enum::kLabel].get_with_shape( diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 611dd7287206..58a535353e10 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -732,8 +732,8 @@ inline void GetIndexRange(const mxnet::TShape& dshape, } inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape, - const index_t i, const int b, - const int e, const int s, + const index_t i, const index_t b, + const index_t e, const index_t s, mxnet::TShape* oshape) { if (!mxnet::dim_size_is_known(dshape, i)) { (*oshape)[i] = -1; @@ -765,7 +765,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, common::StaticArray begin, end, step; GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (int i = 0; i < param.begin.ndim(); ++i) { - const int b = begin[i], e = end[i], s = step[i]; + const index_t b = begin[i], e = end[i], s = step[i]; SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape); } }) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index bd452fb75f6c..cdacce91ab6e 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -19,7 +19,7 @@ import numpy as np import mxnet as mx -from 
mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor from mxnet import gluon, nd from tests.python.unittest.common import with_seed @@ -31,12 +31,6 @@ LARGE_SIZE = LARGE_X * SMALL_Y -def create_2d_tensor(rows, columns, dtype=np.int64): - a = nd.arange(0, rows, dtype=dtype).reshape(rows, 1) - b = nd.broadcast_to(a, shape=(a.shape[0], columns)) - return nd.array(b, dtype=dtype) - - def test_gluon_embedding(): m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X) m.initialize() diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index b45d51d9f1fa..779afd5cb9b2 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -18,7 +18,7 @@ import numpy as np import mxnet as mx -from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, create_vector from mxnet import gluon, nd from tests.python.unittest.common import with_seed @@ -85,11 +85,11 @@ def test_elementwise(): a = nd.ones(shape=LARGE_X) b = nd.ones(shape=LARGE_X) res = a + b - assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + assert res[-1].asnumpy() == 2 res = a + 1 - assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] - res = nd.sqrt(a + 3) - assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + assert res[-1].asnumpy() == 2 + res = nd.sqrt(a + 8) + assert res[-1].asnumpy() == 3 def test_reduce(): @@ -97,13 +97,6 @@ def test_reduce(): assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1] -def test_FullyConnected(): - a = nd.ones(shape=(LARGE_X, SMALL_Y)) - b = nd.ones(shape=(SMALL_Y, SMALL_Y)) - res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True) - assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1] - - def test_broadcast(): a = nd.ones(shape=(LARGE_X, SMALL_Y*2)) b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) @@ -116,7 +109,7 @@ def test_broadcast(): def test_clip(): a = nd.arange(0, LARGE_X) res = nd.clip(a, a_min=100, a_max=1000) - assert np.sum(res[-1].asnumpy() == 1000) == 101 + assert np.sum(res[-1].asnumpy() == 1000) == 1 def test_argmin(): @@ -139,12 +132,6 @@ def test_take(): assert np.sum(res.asnumpy() == 1) == res.shape[0] -def test_slice(): - a = nd.ones(shape=(2, LARGE_X)) - res = nd.slice(a, begin=(1, LARGE_X-1000000000), end=(2, LARGE_X)) - assert np.sum(res[-1].asnumpy() == 1) == res.shape[1] - - def test_slice_assign(): a = nd.ones(shape=LARGE_X) a[LARGE_X-1:LARGE_X] = 1000 @@ -262,13 +249,6 @@ def test_unravel_index(): assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all() -def create_large_vector(size, dtype=np.int64): - a = nd.arange(0, size, dtype=dtype) - # Implicitly calling nd.waitall() - assert a[0] == 0 - return a - - def test_transpose(): b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X) t = b.T @@ -285,27 +265,27 @@ def test_swapaxes(): def test_flip(): b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X) - t = nd.flip(b, axis=0) - assert t.shape == (LARGE_X, 1) - assert t[-1, :].asnumpy() == 0 + t = nd.flip(b, axis=1) + assert t.shape == (1, LARGE_X) + assert t[-1, -1].asnumpy() == 0 def test_softmax(): - input_data = mx.nd.ones(2, LARGE_X) - true_output = np.full(LARGE_X, 0.5) + input_data = nd.ones((2, LARGE_X)) output = nd.softmax(input_data, axis=0) - 
assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5) + assert output[0][0] == 0.5 + assert output[-1][-1] == 0.5 def test_argsort(): - b = create_large_vector(size=LARGE_X) + b = create_vector(size=LARGE_X) s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64) mx.nd.waitall() assert (s[0].asnumpy() == (LARGE_X - 1)).all() def test_sort(): - b = create_large_vector(size=LARGE_X) + b = create_vector(size=LARGE_X) s = nd.sort(b, axis=0, is_ascend=False) assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all() s = nd.sort(b, is_ascend=True) @@ -313,7 +293,7 @@ def test_sort(): def test_topk(): - b = create_large_vector(size=LARGE_X) + b = create_vector(size=LARGE_X) k = nd.topk(b, k=10, axis=0, dtype=np.int64) assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False) From fb9cdb1153a3e15d1caf506aeea16c3a86d12646 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Fri, 23 Aug 2019 22:04:12 +0000 Subject: [PATCH 4/4] removing tests not required for vector testing --- python/mxnet/test_utils.py | 10 +- tests/nightly/test_large_vector.py | 160 +++-------------------------- 2 files changed, 18 insertions(+), 152 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 30d78d2e1593..bb730fd3a007 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -264,15 +264,13 @@ def assign_each2(input1, input2, function): # For testing Large Tensors having total size > 2^32 elements def create_2d_tensor(rows, columns, dtype=np.int64): - a = nd.arange(0, rows, dtype=dtype).reshape(rows, 1) - b = nd.broadcast_to(a, shape=(a.shape[0], columns)) - return nd.array(b, dtype=dtype) + a = mx.nd.arange(0, rows, dtype=dtype).reshape(rows, 1) + b = mx.nd.broadcast_to(a, shape=(a.shape[0], columns)) + return b # For testing Large Vectors having total size > 2^32 elements def create_vector(size, dtype=np.int64): - a = nd.arange(0, size, dtype=dtype) - # Implicitly calling nd.waitall() - assert a[0] == 0 + a = mx.nd.arange(0, size, dtype=dtype) return a def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None, diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 779afd5cb9b2..64bfa8a1d3e9 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -18,32 +18,22 @@ import numpy as np import mxnet as mx -from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, create_vector +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_vector from mxnet import gluon, nd from tests.python.unittest.common import with_seed # dimension constants LARGE_X = 5000000000 MEDIUM_X = 1000000000 -LARGE_Y = 100000 -SMALL_Y = 1 def test_slice(): a = nd.ones(LARGE_X) res = nd.slice(a, begin=(LARGE_X - MEDIUM_X), end=LARGE_X) + assert a[0] == 1 assert res.shape[0] == MEDIUM_X -def test_gluon_embedding(): - m = gluon.nn.Embedding(1, LARGE_Y) - m.initialize() - a = nd.zeros((LARGE_Y, 1)) - b = m(a) - assert b.shape == (LARGE_Y, 1, LARGE_Y) - assert b.asnumpy().size == LARGE_X*2 - - def test_ndarray_zeros(): a = nd.zeros(shape=LARGE_X) assert a[-1] == 0 @@ -73,7 +63,7 @@ def test_ndarray_random_randint(): a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64) low = mx.nd.array([low_large_value], dtype='int64') high = mx.nd.array([high_large_value], dtype='int64') - assert a.__gt__(low) and 
a.__lt__(high) + assert a > low and a < high def test_ndarray_empty(): @@ -93,36 +83,22 @@ def test_elementwise(): def test_reduce(): - a = nd.ones(shape=(LARGE_X, SMALL_Y)) + a = nd.ones(shape=(LARGE_X, 1)) assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1] -def test_broadcast(): - a = nd.ones(shape=(LARGE_X, SMALL_Y*2)) - b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) - res = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y*2)) - assert np.sum(res[-1].asnumpy() == LARGE_X) == res.shape[1] - res = mx.nd.broadcast_like(b, a) - assert np.sum(res[-1].asnumpy() == LARGE_X) == res.shape[1] - - def test_clip(): - a = nd.arange(0, LARGE_X) + a = create_vector(LARGE_X) res = nd.clip(a, a_min=100, a_max=1000) assert np.sum(res[-1].asnumpy() == 1000) == 1 def test_argmin(): - a = nd.arange(0, LARGE_X) + a = create_vector(LARGE_X, dtype=np.float32) + assert a[0] == 0 idx = mx.nd.argmin(a, axis=0) - assert idx.shape[0] == SMALL_Y - - -def test_tile(): - a = nd.arange(0, LARGE_X) - b = nd.tile(a, reps=(1,2)) - assert b[0][LARGE_X] == b[0][0] - assert b[0][LARGE_X-1] == b[0][-1] + assert idx[0] == 0 + assert idx.shape[0] == 1 def test_take(): @@ -169,114 +145,6 @@ def test_Dense(ctx=mx.cpu(0)): assert res.shape == (LARGE_X, 2) -def test_pick(): - a = mx.nd.ones(shape=(LARGE_X, 2)) - b = mx.nd.ones(shape=LARGE_X) - res = mx.nd.pick(a, b) - assert res.shape == b.shape - - -def test_depthtospace(): - def numpy_depth_to_space(x, blocksize): - b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] - tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w]) - tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2]) - y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize]) - return y - - shape_inp = (LARGE_X, 4, 1, 1) - data = rand_ndarray(shape_inp, 'default') - data_np = data.asnumpy() - expected = numpy_depth_to_space(data_np, 2) - output = mx.nd.depth_to_space(data, 2) - assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) - - -def test_spacetodepth(): - def numpy_space_to_depth(x, blocksize): - b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] - tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize]) - tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4]) - y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize]) - return y - - shape_inp = (LARGE_X, 1, 2, 2) - data = rand_ndarray(shape_inp, 'default') - data_np = data.asnumpy() - expected = numpy_space_to_depth(data_np, 2) - output = mx.nd.space_to_depth(data, 2) - assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) - -@with_seed() -def test_diag(): - a_np = np.random.random((LARGE_X, 2)).astype(np.float32) - a = mx.nd.array(a_np) - - # k == 0 - r = mx.nd.diag(a) - assert_almost_equal(r.asnumpy(), np.diag(a_np)) - - # k == 1 - k = 1 - r = mx.nd.diag(a, k=k) - assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) - - # k == -1 - k = -1 - r = mx.nd.diag(a, k=k) - assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) - - -@with_seed() -def test_ravel_multi_index(): - x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, SMALL_Y, 4) - x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, SMALL_Y, 3) - x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, SMALL_Y, 2) - indices_2d = [[x1, x2, x3], [y1, y2, y3]] - idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, 5)) - idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, 5)) - assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3 - - -@with_seed() -def 
test_unravel_index(): - x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, SMALL_Y, 4) - x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, SMALL_Y, 3) - x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, SMALL_Y, 2) - original_2d_indices = [[x1, x2, x3], [y1, y2, y3]] - idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, 5)) - indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, 5)) - assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all() - - -def test_transpose(): - b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X) - t = b.T - assert t.shape == (LARGE_X, 1) - assert t[-1, 0].asnumpy() == (LARGE_X - 1) - - -def test_swapaxes(): - b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(LARGE_X, 1) - t = nd.swapaxes(b, dim1=0, dim2=1) - assert t.shape == (1, LARGE_X) - assert t[0, -1].asnumpy() == (LARGE_X - 1) - - -def test_flip(): - b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X) - t = nd.flip(b, axis=1) - assert t.shape == (1, LARGE_X) - assert t[-1, -1].asnumpy() == 0 - - -def test_softmax(): - input_data = nd.ones((2, LARGE_X)) - output = nd.softmax(input_data, axis=0) - assert output[0][0] == 0.5 - assert output[-1][-1] == 0.5 - - def test_argsort(): b = create_vector(size=LARGE_X) s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64) @@ -287,19 +155,19 @@ def test_argsort(): def test_sort(): b = create_vector(size=LARGE_X) s = nd.sort(b, axis=0, is_ascend=False) - assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all() + assert np.sum(s[-1].asnumpy() == 0).all() s = nd.sort(b, is_ascend=True) assert np.sum(s[0].asnumpy() == 0).all() def test_topk(): b = create_vector(size=LARGE_X) - k = nd.topk(b, k=10, axis=0, dtype=np.int64) - assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y + ind = nd.topk(b, k=10, axis=0, dtype=np.int64) + assert np.sum(ind.asnumpy() == (LARGE_X - 1)) == 1 ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False) assert np.all(ind == val) - l = nd.topk(b, k=1, axis=0, dtype=np.int64, ret_typ="value") - assert l.sum() == (LARGE_X - 1) + val = nd.topk(b, k=1, axis=0, dtype=np.int64, ret_typ="value") + assert val.sum() == (LARGE_X - 1) if __name__ == '__main__':
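A note on the index type change running through this series: the `int` -> `index_t` conversions in softmax_output-inl.h and matrix_op-inl.h are what let slice and softmax address elements beyond position 2**31 - 1, assuming index_t is compiled as a 64-bit type (as in MXNet's large-tensor build). A minimal sketch of the failure mode these patches guard against — illustrative only, not part of the patch, and relying on NumPy's wrap-around casting:

import numpy as np

LARGE_X = 5000000000  # > 2**32, the same constant the nightly tests use

# A 32-bit signed index silently wraps once the element count exceeds
# 2**31 - 1, so offsets computed with a plain `int` land on the wrong
# element instead of failing loudly.
wrapped = np.int64(LARGE_X).astype(np.int32)
print(int(wrapped))  # 705032704, not 5000000000

# With 64-bit index arithmetic the offset survives intact.
assert int(np.int64(LARGE_X)) == LARGE_X

The size of these vectors is also why the tests above assert on boundary elements such as res[-1] or s[0] rather than materializing all five billion values with asnumpy().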