From ff2cc54173f4bdddb579b3a9d73d2c5fa5be0ea1 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 9 Oct 2019 16:49:04 +0000 Subject: [PATCH 1/3] adding tests to very large tensor support for more operators --- tests/nightly/test_large_array.py | 244 ++++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index e51e220c232f..e04edbb46dd1 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1199,6 +1199,250 @@ def test_full(): assert a[-1][-1] == 3 +def test_astype(): + x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + y = x.astype('int32') + assert y.dtype == np.int32 + assert y[-1][-1] == SMALL_Y-1 + + +def test_cast(): + x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) + y = nd.cast(x, np.int32) + assert y.dtype == np.int32 + assert y[-1][-1] == SMALL_Y-1 + + +def test_repeat(): + x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X//2) + y = nd.repeat(x, repeats=2, axis = 1) + assert y.shape == (SMALL_Y, LARGE_X) + assert y[0][1] == 0 + assert y[-1][-1] == SMALL_Y-1 + x = create_2d_tensor(rows=SMALL_Y//2, columns=LARGE_X) + y = nd.repeat(x, repeats=2, axis = 0) + assert y.shape == (SMALL_Y, LARGE_X) + assert y[0][1] == 0 + assert y[-1][0] == SMALL_Y//2-1 + + +def create_input_for_rounding_ops(): + inp = nd.arange(-LARGE_X//2, LARGE_X//2, dtype=np.float64).reshape(1, LARGE_X) + inp = inp/2 + inp = nd.broadcast_to(inp, (SMALL_Y, LARGE_X)) + return inp + + +def test_ceil(): + x = create_input_for_rounding_ops() + y = nd.ceil(x) + assert y[1][LARGE_X//2-2] == -1 + assert y[1][LARGE_X//2-1] == 0 + assert y[1][LARGE_X//2] == 0 + assert y[1][LARGE_X//2+1] == 1 + assert y[1][LARGE_X//2+2] == 1 + + +def test_fix(): + x = create_input_for_rounding_ops() + y = nd.fix(x) + assert y[1][LARGE_X//2-2] == -1 + assert y[1][LARGE_X//2-1] == 0 + assert y[1][LARGE_X//2] == 0 + assert y[1][LARGE_X//2+1] == 0 + assert y[1][LARGE_X//2+2] == 1 + + +def test_floor(): + x = create_input_for_rounding_ops() + y = nd.floor(x) + assert y[1][LARGE_X//2-2] == -1 + assert y[1][LARGE_X//2-1] == -1 + assert y[1][LARGE_X//2] == 0 + assert y[1][LARGE_X//2+1] == 0 + assert y[1][LARGE_X//2+2] == 1 + + +def test_rint(): + x = create_input_for_rounding_ops() + y = nd.rint(x) + assert y[1][LARGE_X//2-2] == -1 + assert y[1][LARGE_X//2-1] == -1 + assert y[1][LARGE_X//2] == 0 + assert y[1][LARGE_X//2+1] == 0 + assert y[1][LARGE_X//2+2] == 1 + + +def test_round(): + x = create_input_for_rounding_ops() + y = nd.round(x) + assert y[1][LARGE_X//2-2] == -1 + assert y[1][LARGE_X//2-1] == -1 + assert y[1][LARGE_X//2] == 0 + assert y[1][LARGE_X//2+1] == 1 + assert y[1][LARGE_X//2+2] == 1 + + +def test_trunc(): + x = create_input_for_rounding_ops() + y = nd.trunc(x) + assert y[1][LARGE_X//2-2] == -1 + assert y[1][LARGE_X//2-1] == 0 + assert y[1][LARGE_X//2] == 0 + assert y[1][LARGE_X//2+1] == 0 + assert y[1][LARGE_X//2+2] == 1 + + +def test_arcsin(): + x = nd.array([-1, -.707, 0, .707, 1]).reshape(1, 5) + x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) + y = nd.arcsin(x) + assert_almost_equal(y[0][0].asnumpy(), -np.pi/2, atol=1e-3) + assert_almost_equal(y[1][1].asnumpy(), -np.pi/4, atol=1e-3) + assert_almost_equal(y[2][2].asnumpy(), 0, atol=1e-3) + assert_almost_equal(y[-2][3].asnumpy(), np.pi/4, atol=1e-3) + assert_almost_equal(y[-1][-1].asnumpy(), np.pi/2, atol=1e-3) + + +def test_arccos(): + x = nd.array([-1, -.707, 0, .707, 1]).reshape(1, 5) + x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) + y = nd.arccos(x) + assert_almost_equal(y[0][0].asnumpy(), np.pi, atol=1e-3) + assert_almost_equal(y[1][1].asnumpy(), 3*np.pi/4, atol=1e-3) + assert_almost_equal(y[2][2].asnumpy(), np.pi/2, atol=1e-3) + assert_almost_equal(y[-2][3].asnumpy(), np.pi/4, atol=1e-3) + assert_almost_equal(y[-1][-1].asnumpy(), 0, atol=1e-3) + + +def test_arctan(): + x = nd.array([-np.Inf, -1, 0, 1, np.Inf]).reshape(1, 5) + x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) + y = nd.arctan(x) + assert_almost_equal(y[0][0].asnumpy(), -np.pi/2, atol=1e-3) + assert_almost_equal(y[1][1].asnumpy(), -np.pi/4, atol=1e-3) + assert_almost_equal(y[2][2].asnumpy(), 0, atol=1e-3) + assert_almost_equal(y[-2][3].asnumpy(), np.pi/4, atol=1e-3) + assert_almost_equal(y[-1][-1].asnumpy(), np.pi/2, atol=1e-3) + + +def test_sin(): + x = nd.array([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]).reshape(1, 5) + x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) + y = nd.sin(x) + assert_almost_equal(y[0][0].asnumpy(), -1, atol=1e-3) + assert_almost_equal(y[1][1].asnumpy(), -.707, atol=1e-3) + assert_almost_equal(y[2][2].asnumpy(), 0, atol=1e-3) + assert_almost_equal(y[-2][3].asnumpy(), .707, atol=1e-3) + assert_almost_equal(y[-1][-1].asnumpy(), 1, atol=1e-3) + + +def test_cos(): + x = nd.array([0, np.pi/4, np.pi/2, 3*np.pi/4, np.pi]).reshape(1, 5) + x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) + y = nd.cos(x) + assert_almost_equal(y[0][0].asnumpy(), 1, atol=1e-3) + assert_almost_equal(y[1][1].asnumpy(), .707, atol=1e-3) + assert_almost_equal(y[2][2].asnumpy(), 0, atol=1e-3) + assert_almost_equal(y[-2][3].asnumpy(), -.707, atol=1e-3) + assert_almost_equal(y[-1][-1].asnumpy(), -1, atol=1e-3) + + +def test_tan(): + x = nd.array([-np.pi/4, 0, np.pi/4]).reshape(1, 3) + x = nd.broadcast_to(x, (LARGE_X*10, x.shape[1])) + y = nd.tan(x) + assert y[0][0] == -1 + assert y[1][1] == 0 + assert y[-1][-1] == 1 + + +def test_radians(): + x = nd.array([0, 90, 180, 270, 360]).reshape(1, 5) + x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) + y = nd.radians(x) + assert_almost_equal(y[0][0].asnumpy(), 0, atol=1e-3) + assert_almost_equal(y[1][1].asnumpy(), np.pi/2, atol=1e-3) + assert_almost_equal(y[2][2].asnumpy(), np.pi, atol=1e-3) + assert_almost_equal(y[-2][3].asnumpy(), 3*np.pi/2, atol=1e-3) + assert_almost_equal(y[-1][-1].asnumpy(), 2*np.pi, atol=1e-3) + + +def test_degrees(): + x = nd.array([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi]).reshape(1, 5) + x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) + y = nd.degrees(x) + assert_almost_equal(y[0][0].asnumpy(), 0, atol=1e-3) + assert_almost_equal(y[1][1].asnumpy(), 90, atol=1e-3) + assert_almost_equal(y[2][2].asnumpy(), 180, atol=1e-3) + assert_almost_equal(y[-2][3].asnumpy(), 270, atol=1e-3) + assert_almost_equal(y[-1][-1].asnumpy(), 360, atol=1e-3) + + +def test_L2Normalization(): + x = nd.ones((2, LARGE_X*2)) + x[0] = 3 + x[1] = 4 + # Channel Mode + z = x.reshape(1, 2, LARGE_X*2) + y = nd.L2Normalization(z, mode='channel') + assert y[0][0][0] == 0.6 + assert y[0][0][-1] == 0.6 + assert y[0][1][0] == 0.8 + assert y[0][1][-1] == 0.8 + # Instance Mode + z = x.T + y = nd.L2Normalization(z, mode='instance') + assert y[0][0] == 0.6 + assert y[0][1] == 0.8 + assert y[-1][0] == 0.6 + assert y[-1][1] == 0.8 + # Spatial Mode + z = z.reshape(1, 200000000, 2) + y = nd.L2Normalization(z, mode='spatial') + assert y[0][0][0] == 0.6 + assert y[0][0][1] == 0.8 + assert y[0][-1][0] == 0.6 + assert y[0][-1][1] == 0.8 + + +def test_instance_norm(): + dtype = np.float32 + forward_check_eps = 1E-3 + axis = -1 + eps = 1E-5 + in_shape = (LARGE_X, 1, SMALL_Y) + ctx = mx.cpu() + + def npy_instance_norm(data, gamma, beta, axis, eps=1E-5): + if axis < 0: + axis += data.ndim + broadcast_shape = [1 for _ in range(data.ndim)] + broadcast_shape[axis] = data.shape[axis] + mean = data.mean(axis=axis, keepdims=True).astype(dtype) + var = data.var(axis=axis, keepdims=True).astype(dtype) + std = np.sqrt(var + dtype(eps)).astype(dtype) + out = gamma * (data - mean) / std + \ + beta + return out + data = np.random.normal(0, 1, in_shape).astype(dtype) + gamma = np.random.normal(0, 1, (1,)).astype(dtype) + beta = np.random.normal(0, 1, (1,)).astype(dtype) + data_s = mx.symbol.Variable('data') + gamma_s = mx.symbol.Variable('gamma') + beta_s = mx.symbol.Variable('beta') + out_s = mx.symbol.InstanceNorm(data=data_s, gamma=gamma_s, beta=beta_s, + eps=eps) + exe = out_s.simple_bind(ctx, data=in_shape) + exe.arg_dict['data'][:] = data + exe.arg_dict['gamma'][:] = gamma + exe.arg_dict['beta'][:] = beta + out_nd = exe.forward()[0] + out = npy_instance_norm(data, gamma, beta, axis, eps) + assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps, + forward_check_eps) + + if __name__ == '__main__': import nose nose.runmodule() From 843b407f9554b5af0441bc52d130dd02d669b854 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Tue, 15 Oct 2019 21:18:43 +0000 Subject: [PATCH 2/3] adding large tensor support for dropout operator --- src/operator/nn/dropout-inl.h | 12 ++++++------ src/operator/tensor/elemwise_binary_broadcast_op.h | 5 +++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index eda9051fd0a2..6387dff96eb7 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -182,10 +182,10 @@ class DropoutOp { * \param input_data Input data to perform the dropout on * \param pkeep Dropout rate (keep when the generated random number is less than this value) */ - MSHADOW_XINLINE static void Map(int id, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, - const int N, - const int step, + const index_t N, + const index_t step, DType *dropout_out, DType *mask_out, const DType *input_data, @@ -199,10 +199,10 @@ class DropoutOp { }; struct BernoulliKernel { /*! \brief Bernoulli kernel for generating mask */ - MSHADOW_XINLINE static void Map(int id, + MSHADOW_XINLINE static void Map(index_t id, RandGenerator gen, - const int N, - const int step, + const index_t N, + const index_t step, DType *mask_out, const real_t pkeep) { RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, { diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index 6a612e6f1cd5..3d3bcfacbd05 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -151,9 +151,10 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet: *new_oshape = mxnet::TShape(odim, 1); int bl = oshape.ndim() - lshape.ndim(); int br = oshape.ndim() - rshape.ndim(); - int j = 0, lprod = 1, rprod = 1, oprod = 1; + int j = 0; + index_t lprod = 1, rprod = 1, oprod = 1; for (int i = 0; i < oshape.ndim(); ++i) { - int l = 1, r = 1, o = oshape[i]; + index_t l = 1, r = 1, o = oshape[i]; if (i >= bl) l = lshape[i-bl]; if (i >= br) r = rshape[i-br]; if ((lprod != rprod || l != r) && From 1f8686e25ffce859499ea6cbc1b5aecaa047687c Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 16 Oct 2019 23:32:26 +0000 Subject: [PATCH 3/3] code refactor and added new comments --- tests/nightly/test_large_array.py | 272 +++++++++++++++--------------- 1 file changed, 134 insertions(+), 138 deletions(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index e04edbb46dd1..7ab1d025b75b 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1227,156 +1227,150 @@ def test_repeat(): def create_input_for_rounding_ops(): + # Creates an vector with values (-LARGE_X/2 .... -2, -1, 0, 1, 2, .... , LARGE_X/2-1) + # then divides each element by 2 i.e (-LARGE_X/4 .... -1, -0.5, 0, 0.5, 1, .... , LARGE_X/4-1) + # and finally broadcasts to inp = nd.arange(-LARGE_X//2, LARGE_X//2, dtype=np.float64).reshape(1, LARGE_X) inp = inp/2 inp = nd.broadcast_to(inp, (SMALL_Y, LARGE_X)) return inp -def test_ceil(): - x = create_input_for_rounding_ops() - y = nd.ceil(x) - assert y[1][LARGE_X//2-2] == -1 - assert y[1][LARGE_X//2-1] == 0 - assert y[1][LARGE_X//2] == 0 - assert y[1][LARGE_X//2+1] == 1 - assert y[1][LARGE_X//2+2] == 1 - - -def test_fix(): - x = create_input_for_rounding_ops() - y = nd.fix(x) - assert y[1][LARGE_X//2-2] == -1 - assert y[1][LARGE_X//2-1] == 0 - assert y[1][LARGE_X//2] == 0 - assert y[1][LARGE_X//2+1] == 0 - assert y[1][LARGE_X//2+2] == 1 +def assert_correctness_of_rounding_ops(output, mid, expected_vals): + # checks verifies 5 values at the middle positions of the input vector + # i.e mid-2, mid-1, mid, mid+1, mid+2 + output_idx_to_inspect = [mid-2, mid-1, mid, mid+1, mid+2] + for i in range(len(output_idx_to_inspect)): + assert output[1][output_idx_to_inspect[i]] == expected_vals[i] -def test_floor(): +# TODO(access2rohit): merge similar tests in large vector and array into one file. +def test_rounding_ops(): x = create_input_for_rounding_ops() - y = nd.floor(x) - assert y[1][LARGE_X//2-2] == -1 - assert y[1][LARGE_X//2-1] == -1 - assert y[1][LARGE_X//2] == 0 - assert y[1][LARGE_X//2+1] == 0 - assert y[1][LARGE_X//2+2] == 1 - -def test_rint(): - x = create_input_for_rounding_ops() - y = nd.rint(x) - assert y[1][LARGE_X//2-2] == -1 - assert y[1][LARGE_X//2-1] == -1 - assert y[1][LARGE_X//2] == 0 - assert y[1][LARGE_X//2+1] == 0 - assert y[1][LARGE_X//2+2] == 1 - - -def test_round(): - x = create_input_for_rounding_ops() - y = nd.round(x) - assert y[1][LARGE_X//2-2] == -1 - assert y[1][LARGE_X//2-1] == -1 - assert y[1][LARGE_X//2] == 0 - assert y[1][LARGE_X//2+1] == 1 - assert y[1][LARGE_X//2+2] == 1 + def check_ceil(): + y = nd.ceil(x) + # expected ouput for middle 5 values after applying ceil() + expected_output = [-1, 0, 0, 1, 1] + assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output) + + def check_fix(): + y = nd.fix(x) + # expected ouput for middle 5 values after applying fix() + expected_output = [-1, 0, 0, 0, 1] + assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output) + + def check_floor(): + y = nd.floor(x) + # expected ouput for middle 5 values after applying floor() + expected_output = [-1, -1, 0, 0, 1] + assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output) + + def check_rint(): + y = nd.rint(x) + # expected ouput for middle 5 values after applying rint() + expected_output = [-1, -1, 0, 0, 1] + assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output) + + def check_round(): + y = nd.round(x) + # expected ouput for middle 5 values after applying round() + expected_output = [-1, -1, 0, 1, 1] + assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output) + + def check_trunc(): + y = nd.trunc(x) + # expected ouput for middle 5 values after applying trunc() + expected_output = [-1, 0, 0, 0, 1] + assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output) + + check_ceil() + check_fix() + check_floor() + check_rint() + check_round() + check_trunc() + + +def create_input_for_trigonometric_ops(vals): + # Creates large vector input of size(LARGE_X*10, SMALL_Y/10) from vals using tile operator + inp = nd.array(vals).reshape(1, 5) + inp = nd.broadcast_to(inp, (LARGE_X*10, SMALL_Y//10)) + return inp -def test_trunc(): - x = create_input_for_rounding_ops() - y = nd.trunc(x) - assert y[1][LARGE_X//2-2] == -1 - assert y[1][LARGE_X//2-1] == 0 - assert y[1][LARGE_X//2] == 0 - assert y[1][LARGE_X//2+1] == 0 - assert y[1][LARGE_X//2+2] == 1 - - -def test_arcsin(): - x = nd.array([-1, -.707, 0, .707, 1]).reshape(1, 5) - x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) - y = nd.arcsin(x) - assert_almost_equal(y[0][0].asnumpy(), -np.pi/2, atol=1e-3) - assert_almost_equal(y[1][1].asnumpy(), -np.pi/4, atol=1e-3) - assert_almost_equal(y[2][2].asnumpy(), 0, atol=1e-3) - assert_almost_equal(y[-2][3].asnumpy(), np.pi/4, atol=1e-3) - assert_almost_equal(y[-1][-1].asnumpy(), np.pi/2, atol=1e-3) - - -def test_arccos(): - x = nd.array([-1, -.707, 0, .707, 1]).reshape(1, 5) - x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) - y = nd.arccos(x) - assert_almost_equal(y[0][0].asnumpy(), np.pi, atol=1e-3) - assert_almost_equal(y[1][1].asnumpy(), 3*np.pi/4, atol=1e-3) - assert_almost_equal(y[2][2].asnumpy(), np.pi/2, atol=1e-3) - assert_almost_equal(y[-2][3].asnumpy(), np.pi/4, atol=1e-3) - assert_almost_equal(y[-1][-1].asnumpy(), 0, atol=1e-3) - - -def test_arctan(): - x = nd.array([-np.Inf, -1, 0, 1, np.Inf]).reshape(1, 5) - x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) - y = nd.arctan(x) - assert_almost_equal(y[0][0].asnumpy(), -np.pi/2, atol=1e-3) - assert_almost_equal(y[1][1].asnumpy(), -np.pi/4, atol=1e-3) - assert_almost_equal(y[2][2].asnumpy(), 0, atol=1e-3) - assert_almost_equal(y[-2][3].asnumpy(), np.pi/4, atol=1e-3) - assert_almost_equal(y[-1][-1].asnumpy(), np.pi/2, atol=1e-3) - - -def test_sin(): - x = nd.array([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]).reshape(1, 5) - x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) - y = nd.sin(x) - assert_almost_equal(y[0][0].asnumpy(), -1, atol=1e-3) - assert_almost_equal(y[1][1].asnumpy(), -.707, atol=1e-3) - assert_almost_equal(y[2][2].asnumpy(), 0, atol=1e-3) - assert_almost_equal(y[-2][3].asnumpy(), .707, atol=1e-3) - assert_almost_equal(y[-1][-1].asnumpy(), 1, atol=1e-3) - - -def test_cos(): - x = nd.array([0, np.pi/4, np.pi/2, 3*np.pi/4, np.pi]).reshape(1, 5) - x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) - y = nd.cos(x) - assert_almost_equal(y[0][0].asnumpy(), 1, atol=1e-3) - assert_almost_equal(y[1][1].asnumpy(), .707, atol=1e-3) - assert_almost_equal(y[2][2].asnumpy(), 0, atol=1e-3) - assert_almost_equal(y[-2][3].asnumpy(), -.707, atol=1e-3) - assert_almost_equal(y[-1][-1].asnumpy(), -1, atol=1e-3) - - -def test_tan(): - x = nd.array([-np.pi/4, 0, np.pi/4]).reshape(1, 3) - x = nd.broadcast_to(x, (LARGE_X*10, x.shape[1])) - y = nd.tan(x) - assert y[0][0] == -1 - assert y[1][1] == 0 - assert y[-1][-1] == 1 - - -def test_radians(): - x = nd.array([0, 90, 180, 270, 360]).reshape(1, 5) - x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) - y = nd.radians(x) - assert_almost_equal(y[0][0].asnumpy(), 0, atol=1e-3) - assert_almost_equal(y[1][1].asnumpy(), np.pi/2, atol=1e-3) - assert_almost_equal(y[2][2].asnumpy(), np.pi, atol=1e-3) - assert_almost_equal(y[-2][3].asnumpy(), 3*np.pi/2, atol=1e-3) - assert_almost_equal(y[-1][-1].asnumpy(), 2*np.pi, atol=1e-3) - - -def test_degrees(): - x = nd.array([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi]).reshape(1, 5) - x = nd.broadcast_to(x, (LARGE_X*10, SMALL_Y//10)) - y = nd.degrees(x) - assert_almost_equal(y[0][0].asnumpy(), 0, atol=1e-3) - assert_almost_equal(y[1][1].asnumpy(), 90, atol=1e-3) - assert_almost_equal(y[2][2].asnumpy(), 180, atol=1e-3) - assert_almost_equal(y[-2][3].asnumpy(), 270, atol=1e-3) - assert_almost_equal(y[-1][-1].asnumpy(), 360, atol=1e-3) +def assert_correctness_of_trigonometric_ops(output, expected_vals): + # checks verifies 5 values at positions(0, 1, -3, -2, -1) of the input vector + output_idx_to_inspect = [0, 1, -3, -2, -1] + for i in range(len(output_idx_to_inspect)): + assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= 1e-3 + + +def test_trigonometric_ops(): + def check_arcsin(): + x = create_input_for_trigonometric_ops([-1, -.707, 0, .707, 1]) + y = nd.arcsin(x) + # expected ouput for indices=(0, 1, -3, -2, -1) after applying arcsin() + expected_output = [-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2] + assert_correctness_of_trigonometric_ops(y, expected_output) + + def check_arccos(): + x = create_input_for_trigonometric_ops([-1, -.707, 0, .707, 1]) + y = nd.arccos(x) + # expected ouput for indices=(0, 1, -3, -2, -1) after applying arccos() + expected_output = [np.pi, 3*np.pi/4, np.pi/2, np.pi/4, 0] + assert_correctness_of_trigonometric_ops(y, expected_output) + + def check_arctan(): + x = create_input_for_trigonometric_ops([-np.Inf, -1, 0, 1, np.Inf]) + y = nd.arctan(x) + # expected ouput for indices=(0, 1, -3, -2, -1) after applying arctan() + expected_output = [-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2] + assert_correctness_of_trigonometric_ops(y, expected_output) + + def check_sin(): + x = create_input_for_trigonometric_ops([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]) + y = nd.sin(x) + # expected ouput for indices=(0, 1, -3, -2, -1) after applying sin() + expected_output = [-1, -.707, 0, .707, 1] + assert_correctness_of_trigonometric_ops(y, expected_output) + + def check_cos(): + x = create_input_for_trigonometric_ops([0, np.pi/4, np.pi/2, 3*np.pi/4, np.pi]) + y = nd.cos(x) + # expected ouput for indices=(0, 1, -3, -2, -1) after applying cos() + expected_output = [1, .707, 0, -.707, -1] + assert_correctness_of_trigonometric_ops(y, expected_output) + + def check_tan(): + x = create_input_for_trigonometric_ops([-np.pi/6, -np.pi/4, 0, np.pi/4, np.pi/6]) + y = nd.tan(x) + # expected ouput for indices=(0, 1, -3, -2, -1) after applying tan() + expected_output = [-.577, -1, 0, 1, .577] + assert_correctness_of_trigonometric_ops(y, expected_output) + + def check_radians(): + x = create_input_for_trigonometric_ops([0, 90, 180, 270, 360]) + y = nd.radians(x) + # expected ouput for indices=(0, 1, -3, -2, -1) after applying radians() + expected_output = [0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi] + assert_correctness_of_trigonometric_ops(y, expected_output) + + def check_degrees(): + x = create_input_for_trigonometric_ops([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi]) + y = nd.degrees(x) + # expected ouput for indices=(0, 1, -3, -2, -1) after applying degrees() + expected_output = [0, 90, 180, 270, 360] + assert_correctness_of_trigonometric_ops(y, expected_output) + + check_arcsin() + check_arccos() + check_arctan() + check_sin() + check_cos() + check_tan() + check_radians() + check_degrees() def test_L2Normalization(): @@ -1414,6 +1408,7 @@ def test_instance_norm(): in_shape = (LARGE_X, 1, SMALL_Y) ctx = mx.cpu() + # Implementation of instance normalization using numpy def npy_instance_norm(data, gamma, beta, axis, eps=1E-5): if axis < 0: axis += data.ndim @@ -1438,6 +1433,7 @@ def npy_instance_norm(data, gamma, beta, axis, eps=1E-5): exe.arg_dict['gamma'][:] = gamma exe.arg_dict['beta'][:] = beta out_nd = exe.forward()[0] + # Calls implementation of instance norm in numpy and compares the output out = npy_instance_norm(data, gamma, beta, axis, eps) assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps, forward_check_eps)