From 99f15e7b7a5dd28f483cf56d40a22e41f8a0c252 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 14 Aug 2019 22:39:11 +0000 Subject: [PATCH] Adding tests to verify support for Large Tensors in additional Ops along with new C_Apis supporting 64bit indexing --- include/mxnet/c_api.h | 10 + python/mxnet/ndarray/ndarray.py | 34 ++- src/c_api/c_api.cc | 33 ++- src/operator/tensor/dot-inl.h | 2 +- tests/nightly/test_large_array.py | 316 +++++++++++++++++++++++++++- tests/nightly/test_large_vector.py | 318 +++++++++++++++++++++++++++++ 6 files changed, 697 insertions(+), 16 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 5ab10b6b2204..8b110a43061a 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -734,6 +734,7 @@ MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle, void *data, size_t size); + /*! * \brief Copy src.data() to dst.data() if i = -1, else dst.aux_data(i) if i >= 0 * This function blocks. Do not use it in performance critical code. @@ -790,6 +791,11 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_end, NDArrayHandle *out); +MXNET_DLL int MXNDArraySlice64(NDArrayHandle handle, + int64_t slice_begin, + int64_t slice_end, + NDArrayHandle *out); + /*! * \brief Index the NDArray along axis 0. * \param handle the handle to the NDArray @@ -801,6 +807,10 @@ MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out); +MXNET_DLL int MXNDArrayAt64(NDArrayHandle handle, + int64_t idx, + NDArrayHandle *out); + /*! * \brief get the storage type of the array */ diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 5f03c65a2e79..b7573650d64d 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -932,14 +932,24 @@ def _get_nd_basic_indexing(self, key): ) handle = NDArrayHandle() flat_self = self.reshape(-1) - check_call( - _LIB.MXNDArraySlice( - flat_self.handle, - mx_uint(flat_begin), - mx_uint(flat_end), - ctypes.byref(handle), + if sys.version_info[0] > 2 and _int64_enabled(): + check_call( + _LIB.MXNDArraySlice64( + flat_self.handle, + ctypes.c_int64(flat_begin), + ctypes.c_int64(flat_end), + ctypes.byref(handle), + ) + ) + else: + check_call( + _LIB.MXNDArraySlice( + flat_self.handle, + ctypes.c_uint32(flat_begin), + ctypes.c_uint32(flat_end), + ctypes.byref(handle), + ) ) - ) sliced_shape = self._basic_indexing_sliced_shape(slc_key, self.shape) sliced = NDArray(handle=handle, writable=self.writable).reshape(sliced_shape) else: @@ -1235,9 +1245,13 @@ def _at(self, idx): if idx < 0: raise IndexError('index %d is out of bounds for axis 0 with size %d' % (idx-length, length)) - check_call(_LIB.MXNDArrayAt( - self.handle, mx_uint(idx), ctypes.byref(handle))) - return self.__class__(handle=handle, writable=self.writable) + if sys.version_info[0] > 2 and _int64_enabled(): + check_call(_LIB.MXNDArrayAt64( + self.handle, ctypes.c_int64(idx), ctypes.byref(handle))) + else: + check_call(_LIB.MXNDArrayAt( + self.handle, ctypes.c_uint32(idx), ctypes.byref(handle))) + return NDArray(handle=handle, writable=self.writable) def reshape(self, *shape, **kwargs): """Returns a **view** of this array with a new shape without altering any data. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index c2b80b3f601c..f6b67d3cb437 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -451,20 +451,35 @@ int MXNDArrayFree(NDArrayHandle handle) { API_END(); } +template +void SliceArray(NDArrayHandle handle, dtype slice_begin, dtype slice_end, NDArray* ptr, + NDArrayHandle* out) { + *ptr = static_cast(handle)->SliceWithRecord(slice_begin, slice_end); + *out = ptr; +} + int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out) { NDArray *ptr = new NDArray(); API_BEGIN(); - *ptr = static_cast(handle)->SliceWithRecord( - slice_begin, slice_end); - *out = ptr; + SliceArray(handle, slice_begin, slice_end, ptr, out); + API_END_HANDLE_ERROR(delete ptr); +} + +int MXNDArraySlice64(NDArrayHandle handle, + int64_t slice_begin, + int64_t slice_end, + NDArrayHandle *out) { + NDArray *ptr = new NDArray(); + API_BEGIN(); + SliceArray(handle, slice_begin, slice_end, ptr, out); API_END_HANDLE_ERROR(delete ptr); } int MXNDArrayAt(NDArrayHandle handle, - mx_uint idx, + uint32_t idx, NDArrayHandle *out) { NDArray *ptr = new NDArray(); API_BEGIN(); @@ -473,6 +488,16 @@ int MXNDArrayAt(NDArrayHandle handle, API_END_HANDLE_ERROR(delete ptr); } +int MXNDArrayAt64(NDArrayHandle handle, + int64_t idx, + NDArrayHandle *out) { + NDArray *ptr = new NDArray(); + API_BEGIN(); + *ptr = static_cast(handle)->AtWithRecord(idx); + *out = ptr; + API_END_HANDLE_ERROR(delete ptr); +} + MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, int ndim, int *dims, diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 77e8e36bbef8..96c869f40d40 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -91,7 +91,7 @@ void DotForward_(const nnvm::NodeAttrs& attrs, inputs[0].get(s), inputs[1].get(s)); } else { - int ma, na, mb, nb, m, n; + index_t ma, na, mb, nb, m, n; if (param.transpose_a) { ma = inputs[0].size(0); na = inputs[0].Size()/ma; diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 02c867720609..bee2e9b812cf 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -19,7 +19,7 @@ import numpy as np import mxnet as mx -from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward from mxnet import gluon, nd from tests.python.unittest.common import with_seed @@ -891,6 +891,320 @@ def test_rpow(): assert c.shape == a.shape +def test_shape(): + b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + mx.nd.waitall() + assert b.shape == (SMALL_Y, LARGE_X) + + +def test_size(): + b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + mx.nd.waitall() + assert b.size == LARGE_SIZE + + +def test_copy(): + a = nd.ones((SMALL_Y, LARGE_X)) + b = a.copy() + nd.waitall() + assert b.shape == a.shape + assert b.size == LARGE_SIZE + + +def test_copy_to(): + a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + b = nd.array(np.zeros((SMALL_Y, LARGE_X))) + c = a.copyto(b) + assert c is b + print(b) + assert b[0][-1] == LARGE_X-1 + + +def test_zeros_like(): + a = nd.array(np.ones((SMALL_Y, LARGE_X))) + b = nd.zeros_like(a) + assert b[-1][-1] == 0 + assert b.shape == a.shape + + +def test_ones_like(): + a = nd.array(np.zeros((SMALL_Y, LARGE_X))) + b = nd.ones_like(a) + assert b[-1][-1] == 1 + assert b.shape == a.shape + + +def test_reshape_like(): + a = nd.array(np.zeros((SMALL_Y, LARGE_X))) + b = nd.array(np.zeros((SMALL_Y//2, LARGE_X*2))) + c = nd.reshape_like(a, b) + assert c.shape == (SMALL_Y//2, LARGE_X*2) + + +def test_flatten(): + a = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y).reshape((LARGE_X//2, 2, SMALL_Y)) + b = nd.flatten(a) + assert b[-1][-1] == (LARGE_X-1) + assert b[-1][0] == (LARGE_X-2) + assert b.shape == (LARGE_X//2, SMALL_Y*2) + + +def test_expand_dims(): + a = nd.array(np.ones((SMALL_Y, LARGE_X))) + b = nd.expand_dims(a, axis=1) + nd.waitall() + assert b.shape == (SMALL_Y, 1, LARGE_X) + + +def test_concat(): + a = nd.array(np.ones((SMALL_Y, LARGE_X))) + b = nd.array(np.zeros((SMALL_Y, LARGE_X))) + c = nd.concat(a,b, dim=0) + assert c.shape == (b.shape[0]*2, LARGE_X) + + +def test_stack(): + a = nd.array(np.ones((SMALL_Y, LARGE_X))) + b = nd.array(np.zeros((SMALL_Y, LARGE_X))) + c = nd.stack(a,b, axis=1) + assert c.shape == (b.shape[0], 2, LARGE_X) + + +def test_broadcast_axes(): + a = create_2d_tensor(rows=1, columns=LARGE_X) + b = nd.broadcast_axis(a, axis=[0], size=2) + assert b.shape == (a.shape[0]*2, a.shape[1]) + + +def test_sum(): + a = nd.array(np.ones((SMALL_Y, LARGE_X))) + b = nd.sum(a, axis=1) + assert b.shape[0] == SMALL_Y + + +def test_prod(): + a = nd.array(np.ones((SMALL_Y, LARGE_X))) + b = nd.prod(a, axis=1) + assert b.shape[0] == SMALL_Y + + +def test_mean(): + a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + b = nd.mean(a, axis=0) + assert b[0] == (SMALL_Y/2-1) + + +def test_min(): + a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + b = nd.min(a, axis=0) + assert b[0] == 0 + assert b[-1] == 0 + + +def test_max(): + a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + b = nd.max(a, axis=0) + assert b[0] == (SMALL_Y-1) + assert b[-1] == (SMALL_Y-1) + + +def test_norm(): + a = np.array(np.full((1, LARGE_X), 3)) + b = np.array(np.full((1, LARGE_X), 4)) + c = nd.array(np.concatenate((a,b), axis=0)) + d = nd.norm(c, ord=2, axis=0) + e = nd.norm(c, ord=1, axis=0) + assert d.shape[0] == LARGE_X + assert e.shape[0] == LARGE_X + assert d[-1] == 5 + assert e[-1] == 7 + + +def test_argmax(): + a = np.ones((SMALL_Y, LARGE_X)) + b = np.zeros((SMALL_Y, LARGE_X)) + c = nd.array(np.concatenate((a,b), axis=0)) + d = nd.argmax(c, axis=0) + assert d.shape[0] == LARGE_X + assert d[-1] == d[0] == 0 + + +def test_relu(): + def frelu(x): + return np.maximum(x, 0.0) + def frelu_grad(x): + return 1.0 * (x > 0.0) + shape = (SMALL_Y, LARGE_X) + x = mx.symbol.Variable("x") + y = mx.sym.relu(x) + xa = np.random.uniform(low=-1.0,high=1.0,size=shape) + eps = 1e-4 + xa[abs(xa) < eps] = 1.0 + ya = frelu(xa) + ga = frelu_grad(xa) + check_symbolic_forward(y, [xa], [ya]) + + +def test_sigmoid(): + def fsigmoid(a): + return np.divide(1.0, (1.0 + np.exp(-a))) + shape = (SMALL_Y, LARGE_X) + x = mx.symbol.Variable("x") + y = mx.sym.sigmoid(x) + xa = np.random.uniform(low=-1.0,high=1.0,size=shape) + ya = fsigmoid(xa) + check_symbolic_forward(y, [xa], [ya]) + + +def np_softmax(x, axis=-1, temperature=1.0): + x = x - np.max(x, axis=axis, keepdims=True) + x = np.exp(x/temperature) + x /= np.sum(x, axis=axis, keepdims=True) + return x + + +def test_log_softmax(): + ndim = 2 + shape = (SMALL_Y, LARGE_X) + axis = np.random.randint(0, ndim) + data = np.random.uniform(-2, 2, size=shape) + sym = mx.sym.log_softmax(axis=axis-ndim) + check_symbolic_forward(sym, [data], [np.log(np_softmax(data, axis=axis)+1e-20)]) + + +def test_iadd(): + a = nd.array(np.ones((SMALL_Y, LARGE_X))) + b = nd.array(np.ones((SMALL_Y, LARGE_X))) + c = b + c += a + assert c.shape == a.shape + assert c[0][-1] == 2 + + +def test_isub(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + b = nd.array(np.ones((SMALL_Y, LARGE_X))) + c = a + c -= b + assert c.shape == a.shape + assert c[0][-1] == 2 + + +def test_imul(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + b = nd.array(np.ones((SMALL_Y, LARGE_X))) + c = b + c *= a + assert c.shape == a.shape + assert c[0][-1] == 3 + + +def test_idiv(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 4))) + b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) + c = a + c /= b + assert c.shape == a.shape + assert c[0][-1] == 2 + + +def test_imod(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) + c = a + c %= b + assert c.shape == a.shape + assert c[0][-1] == 1 + + +def test_eq(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + c = (a == b) + assert np.sum(c[0].asnumpy() == 1).all() + + +def test_neq(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) + b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + c = (a != b) + assert np.sum(c[0].asnumpy() == 1).all() + + +def test_lt(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) + b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + d = (a <= b) + assert np.sum(d[0].asnumpy() == 1).all() + + +def test_lte(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) + b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + c = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) + d = (a <= b) + e = (a <= c) + assert np.sum(d[0].asnumpy() == 1).all() + assert np.sum(e[0].asnumpy() == 1).all() + + +def test_gt(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) + d = (a >= b) + assert np.sum(d[0].asnumpy() == 1).all() + + +def test_gte(): + a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2))) + c = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3))) + d = (a >= b) + e = (a >= c) + assert np.sum(d[0].asnumpy() == 1).all() + assert np.sum(e[0].asnumpy() == 1).all() + + +def test_slice_like(): + a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + b = nd.array(np.ones((SMALL_Y//2, LARGE_X//2))) + c = nd.slice_like(a, b) + d = nd.slice_like(a, b, axes=(0)) + e = nd.slice_like(a, b, axes=(-1)) + assert c.shape == b.shape + assert d.shape[0] == b.shape[0] + assert e.shape[-1] == b.shape[-1] + assert c[0][-1] == 0 + assert d[-1][0] == (SMALL_Y//2-1) + assert e[-1][-1] == (SMALL_Y-1) + + +def test_slice_axis(): + a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + c = nd.slice_axis(a, axis=0, begin=0, end=SMALL_Y//2) + d = nd.slice_axis(a, axis=1, begin=0, end=LARGE_X//2) + assert c.shape[0] == a.shape[0]//2 + assert d.shape[1] == a.shape[1]//2 + assert c[-1][0] == (SMALL_Y//2-1) + assert d[-1][-1] == (SMALL_Y-1) + + +def test_one_hot(): + a = nd.array(np.zeros(SMALL_Y)) + a[0] = 1 + a[-1] = 1 + b = nd.one_hot(a, LARGE_X) + b[0][1] == 1 + b[-1][1] == 1 + + +def test_full(): + a = nd.full((SMALL_Y, LARGE_X), 3) + assert a.shape == (SMALL_Y, LARGE_X) + assert a[SMALL_Y//2][LARGE_X//2] == 3 + assert a[-1][-1] == 3 + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 3a66500957e0..7ea27c221abf 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -25,6 +25,7 @@ # dimension constants LARGE_X = 5000000000 MEDIUM_X = 1000000000 +SMALL_Y = 1 def test_slice(): @@ -33,6 +34,323 @@ def test_slice(): assert res.shape[0] == MEDIUM_X +def test_gluon_embedding(): + m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X) + m.initialize() + a = nd.zeros((MEDIUM_X, SMALL_Y)) + b = m(a) + assert b.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X) + assert b.asnumpy().size == LARGE_SIZE + + +def test_ndarray_zeros(): + a = nd.zeros(shape=LARGE_X) + assert a[-1] == 0 + assert a.shape == (LARGE_X,) + assert a.size == LARGE_X + + +def test_ndarray_ones(): + a = nd.ones(shape=(LARGE_X)) + assert a[-1][0] == 1 + assert nd.sum(a).asnumpy() == LARGE_X + + +@unittest.skip("need to fix") +def test_ndarray_convert(): + a = nd.zeros(shape=(LARGE_X, SMALL_Y)) + b = a.astype(np.int32) + b.wait_to_read() + assert b.dtype == np.int32 + b = a.tostype('row_sparse') + b.wait_to_read() + assert isinstance(b, mx.nd.sparse.RowSparseNDArray) + + +@with_seed() +def test_ndarray_random_uniform(): + a = nd.random.uniform(shape=LARGE_X) + assert a[-1][0] != 0 + + +@with_seed() +def test_ndarray_random_randint(): + a = nd.random.randint(100, 10000, shape=LARGE_X) + assert a.shape == (LARGE_X,) + # check if randint can generate value greater than 2**32 (large) + low_large_value = 2**32 + high_large_value = 2**34 + a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64) + low = mx.nd.array([low_large_value], dtype='int64') + high = mx.nd.array([high_large_value], dtype='int64') + assert a.__gt__(low) and a.__lt__(high) + + +def test_ndarray_empty(): + a = nd.empty(LARGE_X) + assert a.shape == (LARGE_X,) + + +def test_elementwise(): + a = nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) + res = a + b + assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + res = a + 1 + assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + res = nd.sqrt(a + 3) + assert np.sum(res[-1].asnumpy() == 2) == a.shape[1] + + +def test_reduce(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1] + + +@unittest.skip("need to fix") +def test_dot(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + b = nd.ones(shape=(SMALL_Y, SMALL_Y)) + res = nd.dot(a, b) + assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1] + + +def test_FullyConnected(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + b = nd.ones(shape=(SMALL_Y, SMALL_Y)) + res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True) + assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1] + + +def test_broadcast(): + a = nd.ones(shape=(LARGE_X, SMALL_Y*2)) + b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) + res = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y*2)) + assert np.sum(res[-1].asnumpy() == LARGE_X) == res.shape[1] + res = mx.nd.broadcast_like(b, a) + assert np.sum(res[-1].asnumpy() == LARGE_X) == res.shape[1] + + +def test_clip(): + a = nd.arange(0, LARGE_X) + res = nd.clip(a, a_min=100, a_max=1000) + assert np.sum(res[-1].asnumpy() == 1000) == 101 + + +def test_argmin(): + a = nd.arange(0, LARGE_X) + idx = mx.nd.argmin(a, axis=0) + assert idx.shape[0] == SMALL_Y + + +def test_tile(): + a = nd.arange(0, LARGE_X) + b = nd.tile(a, reps=(1,2)) + assert b[0][LARGE_X] == b[0][0] + assert b[0][LARGE_X-1] == b[0][-1] + + +def test_take(): + a = nd.ones(shape=LARGE_X) + idx = nd.arange(LARGE_X - 1000, LARGE_X) + res = nd.take(a, idx) + assert np.sum(res.asnumpy() == 1) == res.shape[0] + + +def test_slice(): + a = nd.ones(shape=(2, LARGE_X)) + res = nd.slice(a, begin=(1, LARGE_X-1000000000), end=(2, LARGE_X)) + assert np.sum(res[-1].asnumpy() == 1) == res.shape[1] + + +def test_slice_assign(): + a = nd.ones(shape=LARGE_X) + a[LARGE_X-1:LARGE_X] = 1000 + assert np.sum(a[-1].asnumpy() == 1000) == 1 + + +def test_expand_dims(): + a = nd.ones(shape=LARGE_X) + res = nd.expand_dims(a, axis=0) + assert res[0][0] == 1 + assert res.shape == (1, a.shape[0]) + + +def test_squeeze(): + a = nd.ones(shape=LARGE_X) + data = nd.expand_dims(a, axis=0) + res = nd.squeeze(data) + assert a[0] == res[0] + assert res.shape == a.shape + + +def test_broadcast_div(): + a = nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) * 2 + res = a / b + assert np.sum(res.asnumpy() == 0.5) == a.shape[0] + + +def test_Dense(ctx=mx.cpu(0)): + data = mx.nd.ones(shape=LARGE_X) + linear = gluon.nn.Dense(2) + linear.initialize(ctx=ctx) + res = linear(data) + res.wait_to_read() + assert res.shape == (LARGE_X, 2) + + +@unittest.skip("need to fix") +def test_where(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + b = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y) + res = nd.where(b > 100, a, b) + assert np.sum(res[-1].asnumpy() == 1) == b.shape[1] + csr_cond = nd.sparse.cast_storage(b < 10, 'csr') + res = nd.sparse.where(csr_cond, a, b) + assert np.sum(res[0].asnumpy() == 1) == 10 + + +def test_pick(): + a = mx.nd.ones(shape=(LARGE_X, 2)) + b = mx.nd.ones(shape=LARGE_X) + res = mx.nd.pick(a, b) + assert res.shape == b.shape + + +def test_depthtospace(): + def numpy_depth_to_space(x, blocksize): + b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] + tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w]) + tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2]) + y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize]) + return y + + shape_inp = (LARGE_X, 4, 1, 1) + data = rand_ndarray(shape_inp, 'default') + data_np = data.asnumpy() + expected = numpy_depth_to_space(data_np, 2) + output = mx.nd.depth_to_space(data, 2) + assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) + + +def test_spacetodepth(): + def numpy_space_to_depth(x, blocksize): + b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] + tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize]) + tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4]) + y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize]) + return y + + shape_inp = (LARGE_X, 1, 2, 2) + data = rand_ndarray(shape_inp, 'default') + data_np = data.asnumpy() + expected = numpy_space_to_depth(data_np, 2) + output = mx.nd.space_to_depth(data, 2) + assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) + +@with_seed() +def test_diag(): + a_np = np.random.random((LARGE_X, 2)).astype(np.float32) + a = mx.nd.array(a_np) + + # k == 0 + r = mx.nd.diag(a) + assert_almost_equal(r.asnumpy(), np.diag(a_np)) + + # k == 1 + k = 1 + r = mx.nd.diag(a, k=k) + assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) + + # k == -1 + k = -1 + r = mx.nd.diag(a, k=k) + assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) + + +@with_seed() +def test_ravel_multi_index(): + x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, SMALL_Y, 4) + x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, SMALL_Y, 3) + x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, SMALL_Y, 2) + indices_2d = [[x1, x2, x3], [y1, y2, y3]] + idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, 5)) + idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, 5)) + assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3 + + +@with_seed() +def test_unravel_index(): + x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, SMALL_Y, 4) + x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, SMALL_Y, 3) + x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, SMALL_Y, 2) + original_2d_indices = [[x1, x2, x3], [y1, y2, y3]] + idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, 5)) + indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, 5)) + assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all() + + +def create_large_vector(size, dtype=np.int64): + a = nd.arange(0, size, dtype=dtype) + # Implicitly calling nd.waitall() + assert a[0] == 0 + return a + + +def test_transpose(): + b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X) + t = b.T + assert t.shape == (LARGE_X, 1) + assert t[-1, 0].asnumpy() == (LARGE_X - 1) + + +def test_swapaxes(): + b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(LARGE_X, 1) + t = nd.swapaxes(b, dim1=0, dim2=1) + assert t.shape == (1, LARGE_X) + assert t[0, -1].asnumpy() == (LARGE_X - 1) + + +def test_flip(): + b = nd.arange(0, LARGE_X, dtype=np.int64).reshape(1, LARGE_X) + t = nd.flip(b, axis=0) + assert t.shape == (LARGE_X, 1) + assert t[-1, :].asnumpy() == 0 + + +def test_softmax(): + input_data = mx.nd.ones(2, LARGE_X) + true_output = np.full(LARGE_X, 0.5) + output = nd.softmax(input_data, axis=0) + assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5) + + +def test_argsort(): + b = create_large_vector(size=LARGE_X) + s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64) + mx.nd.waitall() + assert (s[0].asnumpy() == (LARGE_X - 1)).all() + + +def test_sort(): + b = create_large_vector(size=LARGE_X) + s = nd.sort(b, axis=0, is_ascend=False) + assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all() + s = nd.sort(b, is_ascend=True) + assert np.sum(s[0].asnumpy() == 0).all() + + +def test_topk(): + b = create_large_vector(size=LARGE_X) + k = nd.topk(b, k=10, axis=0, dtype=np.int64) + assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y + ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False) + assert np.all(ind == val) + l = nd.topk(b, k=1, axis=0, dtype=np.int64, ret_typ="value") + assert l.sum() == (LARGE_X - 1) + + if __name__ == '__main__': import nose nose.runmodule()