From 5fcb40ca17b667d8a0873b185f36de02f84410e8 Mon Sep 17 00:00:00 2001 From: Chaitanya Prakash Bapat Date: Wed, 4 Sep 2019 10:48:17 -0700 Subject: [PATCH] Add Large tensor vector test cases (#15941) * add random ops * add shuffle to test large array * shape evaluation after value check * add log, exponent, power ops * fix sequence reverse issue in test_large_array and add sequence ops to test_large_vector * add binary arithmetic * fix lint, minor mistakes in large_array; add nn op to tensor * Trigger notification coz of test_operator.test_laop_6 error * Trigger notification coz of test_operator.test_laop_6 error * Trigger notification bcoz R failures * address comments * normal distribution assert statement fix; randint dtype check * correct layernorm and shuffle * layer norm numpy flaky hence removed, dropout shape fix * comment not working ops * fix multi * Trigger notification * fix seq reverse, uncomment seq mask as it works * index fix and uncomment test * index fix * seq_reverse index fix * uncomment seq reverse test and handle static typecasts * removing commented ops * resolve merge conflict * teardown, lint, remove redundant functions * fix shape assertions and randint low,high * remove waits, add teardown to large_array, change randint assert in large array --- src/operator/sequence_last-inl.h | 24 +- src/operator/sequence_reverse-inl.h | 14 +- tests/nightly/test_large_array.py | 67 +++--- tests/nightly/test_large_vector.py | 332 +++++++++++++++++++++++++++- 4 files changed, 388 insertions(+), 49 deletions(-) diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h index 4c42934f1618..3c3c8b0cd49e 100644 --- a/src/operator/sequence_last-inl.h +++ b/src/operator/sequence_last-inl.h @@ -66,24 +66,24 @@ struct SequenceLastParam : public dmlc::Parameter { template struct SequenceLastKernel { template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, - const IType *idx, int offset1, int offset2, + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in, + const IType *idx, index_t offset1, index_t offset2, mshadow::Shape<2> oshape) { const auto opos = mxnet_op::unravel(i, oshape); - const int seqpos = static_cast(idx[opos[0]]) - 1; - const int ipos = seqpos * offset1 + opos[0] * offset2 + opos[1]; + const index_t seqpos = static_cast(idx[opos[0]]) - 1; + const index_t ipos = seqpos * offset1 + opos[0] * offset2 + opos[1]; KERNEL_ASSIGN(out[i], req, in[ipos]); } }; struct SequenceLastGradKernel { template - MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad, - const IType *idx, int offset1, int offset2, + MSHADOW_XINLINE static void Map(index_t i, DType *in_grad, const DType *out_grad, + const IType *idx, index_t offset1, index_t offset2, mshadow::Shape<2> oshape) { const auto opos = mxnet_op::unravel(i, oshape); - const int seqpos = static_cast(idx[opos[0]]) - 1; - const int ipos = seqpos * offset1 + opos[0] * offset2 + opos[1]; + const index_t seqpos = static_cast(idx[opos[0]]) - 1; + const index_t ipos = seqpos * offset1 + opos[0] * offset2 + opos[1]; in_grad[ipos] += out_grad[i]; } }; @@ -103,8 +103,8 @@ class SequenceLastOp : public Operator { int axis = param_.axis; int out_size = out.size(0) * out.size(1); int max_seq_len = data.size(axis); - int offset1 = axis ? out.size(1) : out_size; - int offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1); + index_t offset1 = axis ? out.size(1) : out_size; + index_t offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1); MXNET_ASSIGN_REQ_SWITCH(req, req_type, { mxnet_op::Kernel, xpu>::Launch( @@ -126,8 +126,8 @@ class SequenceLastOp : public Operator { int out_size = batch * rest; int max_seq_len = in_grad.size(axis); - int offset1 = axis ? rest : out_size; - int offset2 = axis ? (max_seq_len * rest) : rest; + index_t offset1 = axis ? rest : out_size; + index_t offset2 = axis ? (max_seq_len * rest) : rest; mxnet_op::Kernel::Launch( s, out_size, in_grad.dptr_, out_grad.dptr_, indices.dptr_, offset1, diff --git a/src/operator/sequence_reverse-inl.h b/src/operator/sequence_reverse-inl.h index 8e2362f76dd2..e857c6ab9af4 100644 --- a/src/operator/sequence_reverse-inl.h +++ b/src/operator/sequence_reverse-inl.h @@ -67,30 +67,30 @@ struct SequenceReverseParam : public dmlc::Parameter { template struct ReverseKernel { template - MSHADOW_XINLINE static void Map(const int i, DType *const out_data, + MSHADOW_XINLINE static void Map(const index_t i, DType *const out_data, const DType *const in_data, const index_t max_seq_len, const index_t batch_size, const index_t other_dim, const index_t numel, const IType *const indices) { const index_t batch = i / (max_seq_len * other_dim); - const int id = (i / other_dim) % max_seq_len; + const index_t id = (i / other_dim) % max_seq_len; const index_t j = i % other_dim; const index_t num_seq = indices ? static_cast(indices[batch]) : max_seq_len; const index_t padded_periods = max_seq_len - num_seq; // padded part - if (padded_periods > 0 && id < static_cast(padded_periods)) { - const int padded_in_offset = + if (padded_periods > 0 && id < padded_periods) { + const index_t padded_in_offset = (id + num_seq) * batch_size * other_dim + batch * other_dim; KERNEL_ASSIGN(out_data[padded_in_offset + j], req, in_data[padded_in_offset + j]); } // unpadded part - if (id < static_cast(num_seq)) { - const int in_offset = id * batch_size * other_dim + batch * other_dim; - const int out_offset = + if (id < num_seq) { + const index_t in_offset = id * batch_size * other_dim + batch * other_dim; + const index_t out_offset = numel - (id + 1 + padded_periods) * batch_size * other_dim + batch * other_dim; diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index cdacce91ab6e..7622b76a3120 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -21,7 +21,7 @@ from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor from mxnet import gluon, nd -from tests.python.unittest.common import with_seed +from tests.python.unittest.common import with_seed, teardown # dimension constants MEDIUM_X = 10000 @@ -56,10 +56,8 @@ def test_ndarray_ones(): def test_ndarray_convert(): a = nd.zeros(shape=(LARGE_X, SMALL_Y)) b = a.astype(np.int32) - b.wait_to_read() assert b.dtype == np.int32 b = a.tostype('row_sparse') - b.wait_to_read() assert isinstance(b, mx.nd.sparse.RowSparseNDArray) @@ -79,15 +77,16 @@ def test_ndarray_random_randint(): a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64) low = mx.nd.array([low_large_value], dtype='int64') high = mx.nd.array([high_large_value], dtype='int64') - assert a.__gt__(low) and a.__lt__(high) + assert a >= low and a < high + assert a[-1][0].dtype == np.int64 @with_seed() def test_ndarray_random_exponential(): scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) a = nd.random.exponential(scale=scale_array, shape=(SMALL_X, SMALL_Y)) - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) assert a[-1][0][0][0] >= 0 + assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) @with_seed() @@ -96,8 +95,8 @@ def test_ndarray_random_gamma(): beta_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) a = nd.random.gamma(alpha=alpha_array, beta=beta_array, shape=(SMALL_X, SMALL_Y)) - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) assert a[-1][0][0][0] >= 0 + assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) @with_seed() @@ -105,16 +104,16 @@ def test_ndarray_random_multinomial(): # test 1 shape dimension probs = nd.random.uniform(shape=(LARGE_X, SMALL_Y)) a = nd.random.multinomial(probs) - assert a.shape == (LARGE_X,) assert a[-1] >= 0 + assert a.shape == (LARGE_X,) # test for NDArray multi-dimension shape a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y)) - assert a.shape == (LARGE_X, SMALL_X, SMALL_Y) assert a[-1][0][0] >= 0 + assert a.shape == (LARGE_X, SMALL_X, SMALL_Y) # test log_likelihood output shape a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y), get_prob=True) - assert a[0].shape == (LARGE_X, SMALL_X, SMALL_Y) and a[0].shape == a[1].shape assert a[-1][0][0] >= 0 + assert a[0].shape == (LARGE_X, SMALL_X, SMALL_Y) and a[0].shape == a[1].shape @with_seed() @@ -123,8 +122,8 @@ def test_ndarray_random_generalized_negative_binomial(): mu_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) a = nd.random.generalized_negative_binomial(mu=mu_array, alpha=alpha_array, shape=(SMALL_X, SMALL_Y)) - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) assert a[-1][0][0][0] >= 0 + assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) @with_seed() @@ -133,8 +132,8 @@ def test_ndarray_random_negative_binomial(): p_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) a = nd.random.negative_binomial(k=k_array, p=p_array, shape=(SMALL_X, SMALL_Y)) - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) assert a[-1][0][0][0] >= 0 + assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) @with_seed() @@ -144,26 +143,40 @@ def test_ndarray_random_normal(): a = nd.random.normal(loc=loc_array, scale=scale_array, shape=(SMALL_X, SMALL_Y)) assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) - assert a[-1][0][0][0] >= 0 @with_seed() def test_ndarray_random_poisson(): lambda_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y)) a = nd.random.poisson(lam=lambda_array, shape=(SMALL_X, SMALL_Y)) - assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) assert a[-1][0][0][0] >= 0 + assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y) @with_seed() def test_ndarray_random_randn(): a = nd.random.randn(LARGE_X, SMALL_Y) assert a.shape == (LARGE_X, SMALL_Y) - assert a[-1][0] >= 0 - # TODO: Once PR for randn ndarray dtype for loc,scale param merged + # TODO: Once PR #15772 for randn ndarray dtype for loc,scale param merged # Add check for (x,y,m,n) where x,y shape of loc,scale and m,n input shape +@with_seed() +def test_ndarray_random_shuffle(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + a[-1] == 3 # assign 3 to entire last row + a = nd.random.shuffle(a) + # slice first column from shuffled array + # pass LARGE_X values to numpy instead of LARGE_X*SMALL_Y + # could have assigned to last column (so as to pass SMALL_Y) + # but shuffle operation is performed along first axis + unique_a = np.unique(a[:, 0].asnumpy()) + assert len(unique_a) == 2 # only 2 unique values + assert unique_a[0] == 1 # first unique value is 1 + assert unique_a[1] == 3 # second unique value is 3 + assert a.shape[0] == (LARGE_X, SMALL_Y) + + def test_ndarray_empty(): a = nd.empty((LARGE_X, SMALL_Y)) assert a.shape == (LARGE_X, SMALL_Y) @@ -277,7 +290,6 @@ def test_Dense(ctx=mx.cpu(0)): linear = gluon.nn.Dense(100) linear.initialize(ctx=ctx) res = linear(data) - res.wait_to_read() assert res.shape == (50000000, 100) @@ -386,22 +398,22 @@ def test_unravel_index(): def test_transpose(): b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y) t = b.T - assert t.shape == (SMALL_Y, LARGE_X) assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1] + assert t.shape == (SMALL_Y, LARGE_X) def test_swapaxes(): b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y) t = nd.swapaxes(b, dim1=0, dim2=1) - assert t.shape == (SMALL_Y, LARGE_X) assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1] + assert t.shape == (SMALL_Y, LARGE_X) def test_flip(): b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y) t = nd.flip(b, axis=0) - assert t.shape == (LARGE_X, SMALL_Y) assert np.sum(t[-1, :].asnumpy() == 0) == b.shape[1] + assert t.shape == (LARGE_X, SMALL_Y) def test_softmax(): @@ -535,7 +547,9 @@ def test_sequence_reverse(): assert b.shape == a.shape # test with sequence length - b = nd.SequenceReverse(a, sequence_length=[2, 3]) + # 2 rows of batch 1 and 3 rows of batch 2 reversed + b = nd.SequenceReverse(a, sequence_length=nd.array([2, 3]), + use_sequence_length=True) assert b[1][0][0] == a[0][0][0] # check if reversed assert b[-1][0][0] == a[-1][0][0] # check if intact assert b.shape == a.shape @@ -546,7 +560,7 @@ def test_sequence_last(): # test if returns last sequence b = nd.SequenceLast(a) - assert_almost_equal(b, a[-1]) # only checks for (2,SMALL_Y) tensor + assert_almost_equal(b.asnumpy(), a[-1].asnumpy()) # only checks for (2,SMALL_Y) tensor assert b.shape == (2, SMALL_Y) # test with sequence length @@ -715,17 +729,18 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5): assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps, forward_check_eps) + # TODO: correctness of dropout # currently only test for dropout to work # since testing for correctness involves flakiness issue #14288 def test_dropout(): - shape = (10, 10) + shape = (LARGE_X, SMALL_Y) x = mx.sym.var('data') y = mx.sym.Dropout(x, p=1, cudnn_off=True) exe = y.simple_bind(ctx=default_context(), data=shape) exe.arg_arrays[0][:] = 1 out = exe.forward(is_train=True) - out[0].wait_to_read() + assert out.shape == out.shape def test_activation(): @@ -736,7 +751,7 @@ def test_activation(): # Hyperbolic tangent (tanh) # y = (exp(x)-exp(-x))/(exp(x)+exp(-x)) a = mx.nd.Activation(a, act_type="tanh") - tanh_x = (np.exp(-2)-np.exp(2))/(np.exp(-2)+np.exp(2)) + tanh_x = (np.exp(test_x)-np.exp(-test_x))/(np.exp(test_x)+np.exp(-test_x)) assert a[-1][-1] == tanh_x # Recitified Linear Unit (relu) @@ -763,8 +778,6 @@ def test_activation(): def test_batchnorm(): shape = (LARGE_X, SMALL_Y) axis = 1 # default - expand_shape = [1] * len(shape) - expand_shape[axis] = shape[axis] nch = shape[axis] data = mx.nd.ones(shape=shape) @@ -775,7 +788,7 @@ def test_batchnorm(): output = mx.nd.BatchNorm(data, bn_gamma, bn_beta, bn_running_mean, bn_running_var) - output.wait_to_read() + assert output.shape == shape def test_add(): diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 4e1b48c8d047..34cc368d18fe 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. +import math import numpy as np import mxnet as mx from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, create_vector from mxnet import gluon, nd -from tests.python.unittest.common import with_seed +from tests.python.unittest.common import with_seed, teardown # dimension constants LARGE_X = 5000000000 @@ -63,7 +64,7 @@ def test_ndarray_random_randint(): a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64) low = mx.nd.array([low_large_value], dtype='int64') high = mx.nd.array([high_large_value], dtype='int64') - assert a > low and a < high + assert a >= low and a < high def test_ndarray_empty(): @@ -141,7 +142,6 @@ def test_Dense(ctx=mx.cpu(0)): linear = gluon.nn.Dense(2) linear.initialize(ctx=ctx) res = linear(data) - res.wait_to_read() assert res.shape == (LARGE_X, 2) @@ -169,6 +169,332 @@ def test_topk(): assert val.sum() == (LARGE_X - 1) +@with_seed() +def test_ndarray_random_exponential(): + a = nd.random.exponential(shape=LARGE_X) + assert a[-1] >= 0. + assert a.shape[0] == LARGE_X + + +@with_seed() +def test_ndarray_random_gamma(): + a = nd.random.gamma(shape=LARGE_X) + assert a[-1] >= 0. + assert a.shape[0] == LARGE_X + + +@with_seed() +def test_ndarray_random_generalized_negative_binomial(): + a = nd.random.generalized_negative_binomial(shape=LARGE_X) + assert a[-1] >= 0. + assert a.shape[0] == LARGE_X + + +@with_seed() +def test_ndarray_random_multinomial(): + a = nd.random.multinomial(nd.random.uniform(shape=LARGE_X)) + assert a[-1] >= 0. + assert a.shape[0] == 1 + + +@with_seed() +def test_ndarray_random_negative_binomial(): + a = nd.random.negative_binomial(shape=LARGE_X) + assert a[-1] >= 0. + assert a.shape[0] == LARGE_X + + +@with_seed() +def test_ndarray_random_normal(): + a = nd.random.normal(shape=LARGE_X) + assert a.shape[0] == LARGE_X + + +@with_seed() +def test_ndarray_random_poisson(): + a = nd.random.poisson(shape=LARGE_X) + assert a[-1] >= 0. + assert a.shape[0] == LARGE_X + + +@with_seed() +def test_ndarray_random_randn(): + a = nd.random.randn(LARGE_X) + assert a.shape[0] == LARGE_X + + +@with_seed() +def test_ndarray_random_shuffle(): + a = nd.ones(shape=LARGE_X) + a[-1] = 3 + a = nd.random.shuffle(a) + unique_a = np.unique(a.asnumpy()) + assert len(unique_a) == 2 # only 2 unique values + assert unique_a[0] == 1 # first unique value is 1 + assert unique_a[1] == 3 # second unique value is 3 + assert a.shape[0] == LARGE_X + + +def test_exponent_logarithm_operators(): + a = 2*nd.ones(shape=LARGE_X) + # exponent + result = nd.exp(a) + assert result[-1] == 7.389056 + assert result.shape == a.shape + + # exponent minus 1 + result = nd.expm1(a) + assert result[-1] == 6.389056 + assert result.shape == a.shape + + # log2 + result = nd.log2(a) + assert result[-1] == 1 + assert result.shape == a.shape + + # log10 + result = nd.log10(a) + assert result[-1] == 0.30103 + assert result.shape == a.shape + + # log1p + result = nd.log1p(a) + assert result[-1] == 1.0986123 + assert result.shape == a.shape + + # log + result = nd.log(a) + assert result[-1] == 0.6931472 + assert result.shape == a.shape + + +def test_power_operators(): + a = 2*nd.ones(shape=LARGE_X) + # sqrt + result = nd.sqrt(a) + assert result[-1] == 1.4142135 + assert result.shape == a.shape + + # rsqrt + result = nd.rsqrt(a) + assert result[-1] == 0.70710677 + assert result.shape == a.shape + + # cbrt + result = nd.cbrt(a) + assert result[-1] == 1.2599211 + assert result.shape == a.shape + + # rcbrt + result = nd.rcbrt(a) + assert result[-1] == 0.7937005 + assert result.shape == a.shape + + # square + result = nd.square(a) + assert result[-1] == 4 + assert result.shape == a.shape + + # reciprocal + result = nd.reciprocal(a) + assert result[-1] == 0.5 + assert result.shape == a.shape + + +def test_sequence_mask(): + # Sequence Mask input [max_sequence_length, batch_size] + # test with input batch_size = 2 + a = nd.arange(0, LARGE_X * 2).reshape(LARGE_X, 2) + + # test as identity operator + b = nd.SequenceMask(a) + assert b[-1][0] == a[-1][0] + assert b.shape == a.shape + + # test with default mask + b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]), + use_sequence_length=True) + assert b[0][1] == a[0][1] # first sequence of each batch kept + assert b[-1][-1] != a[-1][-1] # rest sequences masked + assert b[-1][-1] == 0 + + # test with mask value + b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]), + use_sequence_length=True, value=-1) + assert b[-1][-1] == -1 + + +def test_sequence_reverse(): + a = nd.arange(0, LARGE_X * 2).reshape(LARGE_X, 2) + # test as reverse operator + b = nd.SequenceReverse(a) + assert b[-1][0] == a[0][0] + assert b.shape == a.shape + + # test with sequence length + b = nd.SequenceReverse(a, sequence_length=nd.array([2, 3]), + use_sequence_length=True) + assert b[1][0] == a[0][0] # check if reversed + assert b[-1][0] == a[-1][0] # check if intact + assert b.shape == a.shape + + +def test_sequence_last(): + a = nd.arange(0, LARGE_X * 2).reshape(LARGE_X, 2) + + # test if returns last sequence + b = nd.SequenceLast(a) + assert_almost_equal(b.asnumpy(), a[-1].asnumpy()) + assert b.shape == (2,) + + # test with sequence length + # parameter sequence_length - NDArray with shape (batch_size) + # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2 + b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]), + use_sequence_length=True) + # check if it takes 2nd sequence from the first batch + assert b[0] == a[1][0] + + +# TODO: correctness of layernorm +# numpy implementation for large vector is flaky +def test_layer_norm(): + axis = 0 + eps = 1E-5 + in_shape = LARGE_X + + data = nd.random.normal(0, 1, in_shape) + gamma = nd.random.normal(0, 1, in_shape) + beta = nd.random.normal(0, 1, in_shape) + mx_out = nd.LayerNorm(data, gamma, beta, axis, eps) + assert mx_out.shape == (in_shape,) + + +# TODO: correctness of batchnorm +# in future, we could test if mean, var of output +# matches target output's mean, var +def test_batchnorm(): + shape = LARGE_X + axis = 0 # since vector + + data = mx.nd.ones(shape=shape) + bn_gamma = mx.nd.random.uniform(shape=shape) + bn_beta = mx.nd.random.uniform(shape=shape) + bn_running_mean = mx.nd.zeros(shape) + bn_running_var = mx.nd.ones(shape) + + output = mx.nd.BatchNorm(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var, axis=axis) + assert output.shape == (shape,) + + +def test_add(): + a = nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) + c = b + c = c.__add__(a) + assert c[-1] == 2 + assert c.shape == a.shape + + +def test_sub(): + a = 3*nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) + c = b + c = c.__sub__(a) + assert c[-1] == -2 + assert c.shape == a.shape + + +def test_rsub(): + a = 3*nd.ones(shape=LARGE_X) + b = nd.ones(shape=LARGE_X) + c = b + c = c.__rsub__(a) + assert c[-1] == 2 + assert c.shape == a.shape + + +def test_neg(): + a = nd.ones(shape=LARGE_X) + c = a + c = c.__neg__() + assert c[-1] == -1 + assert c.shape == a.shape + + +def test_mul(): + a = 2*nd.ones(shape=LARGE_X) + b = 3*nd.ones(shape=LARGE_X) + c = b + c = c.__mul__(a) + assert c[-1] == 6 + assert c.shape == a.shape + + +def test_div(): + a = 2*nd.ones(shape=LARGE_X) + b = 3*nd.ones(shape=LARGE_X) + c = b + c = c.__div__(a) + assert c[-1] == 3/2 + assert c.shape == a.shape + + +def test_rdiv(): + a = 2*nd.ones(shape=LARGE_X) + b = 3*nd.ones(shape=LARGE_X) + c = b + c = c.__rdiv__(a) + assert c[-1] == 2/3 + assert c.shape == a.shape + + +def test_mod(): + a = 2*nd.ones(shape=LARGE_X) + b = 3*nd.ones(shape=LARGE_X) + c = b + c = c.__mod__(a) + assert c[-1] == 1 + assert c.shape == a.shape + + +def test_rmod(): + a = 2*nd.ones(shape=LARGE_X) + b = 3*nd.ones(shape=LARGE_X) + c = b + c = c.__rmod__(a) + assert c[-1] == 2 + assert c.shape == a.shape + + +def test_imod(): + a = 2*nd.ones(shape=LARGE_X) + b = 3*nd.ones(shape=LARGE_X) + c = b + c = c.__imod__(a) + assert c[-1] == 1 + assert c.shape == a.shape + + +def test_pow(): + a = 2*nd.ones(shape=LARGE_X) + b = 3*nd.ones(shape=LARGE_X) + c = b + c = c.__pow__(a) + assert c[-1] == 9 + assert c.shape == a.shape + + +def test_rpow(): + a = 2*nd.ones(shape=LARGE_X) + b = 3*nd.ones(shape=LARGE_X) + c = b + c = c.__rpow__(a) + assert c[-1] == 8 + assert c.shape == a.shape + + def test_shape(): b = create_vector(size=LARGE_X) #explicit wait_to_read()