From 01b0f48101646dcb0c3f5ed0a73ccbf03838669b Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 8 Aug 2019 09:22:42 -0700 Subject: [PATCH 01/10] sequence_last, sequence_reverse, sequence_mask --- tests/nightly/test_large_array.py | 49 +++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 0df481a01987..101914b518b6 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -352,6 +352,55 @@ def test_topk(): assert l.sum() == np.sum(np.arange(0, SMALL_Y)) +def test_sequence_mask(): + # Sequence Mask input - [max_sequence_length, batch_size, other_feature_dims] + # test with input batch_size = 2 + a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y) + + # test as identity operator + b = nd.SequenceMask(a) + assert b[-1][0][1] == a[-1][0][1] + assert b.shape == a.shape + + # test with default mask + b = nd.SequenceMask(a, sequence_length = nd.array([1,1]), use_sequence_length = True) + assert b[0][1][-1] == a[0][1][-1] #first sequence of each batch kept + assert b[-1][-1][-1] != a[-1][-1][-1] #rest sequences masked + assert b[-1][-1][-1] == 0 + + # test with mask value + b = nd.SequenceMask(a, sequence_length = nd.array([1,1]), use_sequence_length = True, value = -1) + assert b[-1][-1][-1] == -1 + +def test_sequence_reverse(): + a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y) + # test as reverse operator + b = nd.SequenceReverse(a) + assert b[-1][0][0] == a[0][0][0] + assert b.shape == a.shape + + # test with sequence length + b = nd.SequenceReverse(a, sequence_length=[2,3]) + assert b[1][0][0] == a[0][0][0] #check if reversed + assert b[-1][0][0] == a[-1][0][0] #check if intact + assert b.shape == a.shape + + +def test_sequence_last(): + a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y) + + # test if returns last sequence + b = nd.SequenceLast(a) + assert_almost_equal(b, a[-1]) # only checks for (2,SMALL_Y) tensor + assert b.shape == (2, SMALL_Y) + + # test with sequence length + # parameter sequence_length - NDArray with shape (batch_size) + # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2 + b = nd.SequenceLast(a, sequence_length = mx.nd.array([2,3]), use_sequence_length = True) + assert b[0][-1] == a[1][0][-1] #check if it takes 2nd sequence from the first batch + + if __name__ == '__main__': import nose nose.runmodule() From c4062e93dd0da3f7491e089dfa5cc61c266ff608 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 8 Aug 2019 15:21:05 -0700 Subject: [PATCH 02/10] working softmax_cross_entropy --- tests/nightly/test_large_array.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 101914b518b6..5f1f91b60790 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -401,6 +401,25 @@ def test_sequence_last(): assert b[0][-1] == a[1][0][-1] #check if it takes 2nd sequence from the first batch +def test_softmax_cross_entropy(): + # dtype of input data, mxnet cross entropy set explicitly to float64 + # numpy implicitly takes care of double precision + batch_size = SMALL_Y + num_labels = LARGE_X + input_data = mx.nd.ones((batch_size, num_labels),dtype="float64") + input_label = mx.nd.zeros((batch_size,),dtype="float64") + + true_softmax = np.full((batch_size, num_labels), (1 / num_labels)) + # use 1/batch_size when softmax axis=0 + # here 1/num_labels 
since softmax_cross_entropy uses default axis + # by default axis=1 + np_one_hot_label = np.zeros((batch_size,num_labels)) + np_one_hot_label[:,0] = 1 + + true_softmax_cross_entropy = np.sum(-np.log(true_softmax) * np_one_hot_label) + mx_softmax_cross_entropy = mx.nd.softmax_cross_entropy(input_data,input_label, dtype="float64") + assert_almost_equal(mx_softmax_cross_entropy.asnumpy(), true_softmax_cross_entropy, rtol=1e-3, atol=1e-5) + if __name__ == '__main__': import nose nose.runmodule() From 6bd4f689844abbcedcfa435906def9dcac700c70 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Fri, 9 Aug 2019 11:16:30 -0700 Subject: [PATCH 03/10] fix linting, add index_copy --- tests/nightly/test_large_array.py | 76 +++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 25 deletions(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 5f1f91b60790..33017f770cbe 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -217,9 +217,11 @@ def test_pick(): def test_depthtospace(): def numpy_depth_to_space(x, blocksize): b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] - tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w]) + tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, + w]) tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2]) - y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize]) + y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, + w * blocksize]) return y shape_inp = (LARGE_X, 8, 4, 2) @@ -233,9 +235,11 @@ def numpy_depth_to_space(x, blocksize): def test_spacetodepth(): def numpy_space_to_depth(x, blocksize): b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] - tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize]) + tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, + blocksize]) tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4]) - y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize]) + y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, + w // blocksize]) return y shape_inp = (LARGE_X, 2, 8, 4) @@ -245,6 +249,7 @@ def numpy_space_to_depth(x, blocksize): output = mx.nd.space_to_depth(data, 2) assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3) + @with_seed() def test_diag(): a_np = np.random.random((LARGE_X, SMALL_Y)).astype(np.float32) @@ -276,7 +281,8 @@ def test_ravel_multi_index(): x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y) x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y) indices_2d = [[x1, x2, x3], [y1, y2, y3]] - idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, SMALL_Y)) + idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), + shape=(LARGE_X, SMALL_Y)) idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, SMALL_Y)) assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3 @@ -288,7 +294,8 @@ def test_unravel_index(): x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y) original_2d_indices = [[x1, x2, x3], [y1, y2, y3]] idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, SMALL_Y)) - indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, SMALL_Y)) + indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), + shape=(LARGE_X, SMALL_Y)) assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all() @@ -345,7 +352,8 @@ def test_topk(): b = create_2d_tensor(rows=LARGE_X, 
columns=SMALL_Y) k = nd.topk(b, k=10, axis=0, dtype=np.int64) assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y - ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False) + ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", + is_ascend=False) assert np.all(ind == val) b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X) l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value") @@ -353,7 +361,7 @@ def test_topk(): def test_sequence_mask(): - # Sequence Mask input - [max_sequence_length, batch_size, other_feature_dims] + # Sequence Mask input [max_sequence_length, batch_size, other_feature_dims] # test with input batch_size = 2 a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y) @@ -363,15 +371,18 @@ def test_sequence_mask(): assert b.shape == a.shape # test with default mask - b = nd.SequenceMask(a, sequence_length = nd.array([1,1]), use_sequence_length = True) - assert b[0][1][-1] == a[0][1][-1] #first sequence of each batch kept - assert b[-1][-1][-1] != a[-1][-1][-1] #rest sequences masked + b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]), + use_sequence_length=True) + assert b[0][1][-1] == a[0][1][-1] # first sequence of each batch kept + assert b[-1][-1][-1] != a[-1][-1][-1] # rest sequences masked assert b[-1][-1][-1] == 0 # test with mask value - b = nd.SequenceMask(a, sequence_length = nd.array([1,1]), use_sequence_length = True, value = -1) + b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]), + use_sequence_length=True, value=-1) assert b[-1][-1][-1] == -1 + def test_sequence_reverse(): a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y) # test as reverse operator @@ -380,9 +391,9 @@ def test_sequence_reverse(): assert b.shape == a.shape # test with sequence length - b = nd.SequenceReverse(a, sequence_length=[2,3]) - assert b[1][0][0] == a[0][0][0] #check if reversed - assert b[-1][0][0] == a[-1][0][0] #check if intact + b = nd.SequenceReverse(a, sequence_length=[2, 3]) + assert b[1][0][0] == a[0][0][0] # check if reversed + assert b[-1][0][0] == a[-1][0][0] # check if intact assert b.shape == a.shape @@ -391,14 +402,16 @@ def test_sequence_last(): # test if returns last sequence b = nd.SequenceLast(a) - assert_almost_equal(b, a[-1]) # only checks for (2,SMALL_Y) tensor + assert_almost_equal(b, a[-1]) # only checks for (2,SMALL_Y) tensor assert b.shape == (2, SMALL_Y) # test with sequence length # parameter sequence_length - NDArray with shape (batch_size) # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2 - b = nd.SequenceLast(a, sequence_length = mx.nd.array([2,3]), use_sequence_length = True) - assert b[0][-1] == a[1][0][-1] #check if it takes 2nd sequence from the first batch + b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]), + use_sequence_length=True) + # check if it takes 2nd sequence from the first batch + assert b[0][-1] == a[1][0][-1] def test_softmax_cross_entropy(): @@ -406,19 +419,32 @@ def test_softmax_cross_entropy(): # numpy implicitly takes care of double precision batch_size = SMALL_Y num_labels = LARGE_X - input_data = mx.nd.ones((batch_size, num_labels),dtype="float64") - input_label = mx.nd.zeros((batch_size,),dtype="float64") + input_data = mx.nd.ones((batch_size, num_labels), dtype="float64") + input_label = mx.nd.zeros((batch_size,), dtype="float64") true_softmax = np.full((batch_size, num_labels), (1 / num_labels)) # use 1/batch_size when softmax axis=0 # here 1/num_labels since softmax_cross_entropy uses default axis # 
by default axis=1 - np_one_hot_label = np.zeros((batch_size,num_labels)) - np_one_hot_label[:,0] = 1 + np_one_hot_label = np.zeros((batch_size, num_labels)) + np_one_hot_label[:, 0] = 1 + + true_softmax_cross_entropy = np.sum(-np.log(true_softmax) * + np_one_hot_label) + mx_softmax_cross_entropy = mx.nd.softmax_cross_entropy(input_data, + input_label, + dtype="float64") + assert_almost_equal(mx_softmax_cross_entropy.asnumpy(), + true_softmax_cross_entropy, rtol=1e-3, atol=1e-5) + + +def test_index_copy(): + x = mx.nd.zeros((LARGE_X, SMALL_Y)) + t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y)) + index = mx.nd.array([LARGE_X - 1]) - true_softmax_cross_entropy = np.sum(-np.log(true_softmax) * np_one_hot_label) - mx_softmax_cross_entropy = mx.nd.softmax_cross_entropy(input_data,input_label, dtype="float64") - assert_almost_equal(mx_softmax_cross_entropy.asnumpy(), true_softmax_cross_entropy, rtol=1e-3, atol=1e-5) + x = mx.nd.contrib.index_copy(x, index, t) + assert x[-1][-1] == t[0][-1] if __name__ == '__main__': import nose From 3ba249f1c5d702f639d251115c5cad756ea65256 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Fri, 9 Aug 2019 14:55:27 -0700 Subject: [PATCH 04/10] add softmax output --- tests/nightly/test_large_array.py | 33 ++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 33017f770cbe..5171de1a9039 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -17,7 +17,7 @@ import numpy as np import mxnet as mx -from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d +from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context from mxnet import gluon, nd from tests.python.unittest.common import with_seed @@ -446,6 +446,37 @@ def test_index_copy(): x = mx.nd.contrib.index_copy(x, index, t) assert x[-1][-1] == t[0][-1] + +# def test_dropout(): +# a = mx.nd.ones((LARGE_X, SMALL_Y)) +# # test dropout ratio +# x = nx.sym.var('data') + + +def testSoftmaxOutput(): + x = mx.sym.Variable('x') + label = mx.sym.Variable('label') + x_nd = mx.nd.ones((LARGE_X, SMALL_Y)) + grad_x = mx.nd.zeros((LARGE_X, SMALL_Y)) + label_nd = mx.nd.ones((LARGE_X)) + + sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0, + use_ignore=False) + ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd}, + args_grad={'x': grad_x}) + + ex.forward(is_train=True) + softmax_out = ex.outputs[0][0].asnumpy() + expected_softmax_out = (1/SMALL_Y)*mx.nd.ones((SMALL_Y)).asnumpy() + assert np.isclose(softmax_out, expected_softmax_out).all() + + ex.backward(is_train=True) + grad_out = ex.grad_arrays[0][0].asnumpy() + k = int(label_nd[0].asscalar()) + expected_grad_out = np.zeros((SMALL_Y,)) + expected_grad_out[k] = -1 + assert np.isclose(grad_out - softmax_out, expected_grad_out).all() + if __name__ == '__main__': import nose nose.runmodule() From d791606ff18f1f4461e42edf52150843e4eb6e08 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Fri, 9 Aug 2019 16:47:54 -0700 Subject: [PATCH 05/10] add leaky relu --- tests/nightly/test_large_array.py | 39 +++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 5171de1a9039..194f0ff2592b 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -447,7 +447,7 @@ def test_index_copy(): assert x[-1][-1] == t[0][-1] -# 
def test_dropout(): +# def test_rnn_and_dropouts(): # a = mx.nd.ones((LARGE_X, SMALL_Y)) # # test dropout ratio # x = nx.sym.var('data') @@ -460,7 +460,7 @@ def testSoftmaxOutput(): grad_x = mx.nd.zeros((LARGE_X, SMALL_Y)) label_nd = mx.nd.ones((LARGE_X)) - sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0, + sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0, use_ignore=False) ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd}, args_grad={'x': grad_x}) @@ -477,6 +477,41 @@ def testSoftmaxOutput(): expected_grad_out[k] = -1 assert np.isclose(grad_out - softmax_out, expected_grad_out).all() + +def test_leaky_relu(): + a = -1*mx.nd.ones((LARGE_X, SMALL_Y)) + + def test_leaky(): + res = mx.nd.LeakyReLU(a, act_type="leaky", slope=0.3) + assert res[-1][-1].asnumpy() == 0.3*a[-1][-1].asnumpy() + + def test_elu(): + res = mx.nd.LeakyReLU(a, act_type="elu", slope=0.3) + assert res[-1][-1].asnumpy() == 0.3*(np.exp(a[-1][-1].asnumpy())-1) + + def test_selu(): + lam = 1.0507009873554804934193349852946 + alpha = 1.6732632423543772848170429916717 + res = mx.nd.LeakyReLU(a, act_type="selu") + assert res[-1][-1].asnumpy() == (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1)) + + def test_prelu(): + res = mx.nd.LeakyReLU(a, act_type="prelu", gamma=mx.nd.array([0.3])) + assert res[-1][-1].asnumpy() == 0.3 * a[-1][-1].asnumpy() + + def test_rrelu(): + lower = 0.125 + upper = 0.333999991 + res = mx.nd.LeakyReLU(a, act_type="rrelu") + assert res[-1][-1].asnumpy() == (lower + upper) / 2 * a[-1][-1].asnumpy() + + test_leaky() + test_elu() + test_selu() + test_prelu() + test_rrelu() + + if __name__ == '__main__': import nose nose.runmodule() From 00c7e352c5e2adb1b664fc34595567573a61680c Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 12 Aug 2019 09:29:56 -0700 Subject: [PATCH 06/10] add pooling --- tests/nightly/test_large_array.py | 44 ++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 194f0ff2592b..a432988fef36 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -495,9 +495,12 @@ def test_selu(): res = mx.nd.LeakyReLU(a, act_type="selu") assert res[-1][-1].asnumpy() == (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1)) - def test_prelu(): - res = mx.nd.LeakyReLU(a, act_type="prelu", gamma=mx.nd.array([0.3])) - assert res[-1][-1].asnumpy() == 0.3 * a[-1][-1].asnumpy() + # def test_prelu(): + # res = mx.nd.LeakyReLU(a, act_type="prelu", gamma=mx.nd.array([0.3])) + # assert res[-1][-1].asnumpy() == 0.3 * a[-1][-1].asnumpy() + # fails with large tensor shape + # all values from [0][0] till [14100654][3] have correct -0.3 + # all values from [14100654][4] till [-1][-1] have 0. def test_rrelu(): lower = 0.125 @@ -508,10 +511,43 @@ def test_rrelu(): test_leaky() test_elu() test_selu() - test_prelu() + # test_prelu() test_rrelu() +def test_pooling(): + a = mx.nd.ones((MEDIUM_X, MEDIUM_X, SMALL_Y, SMALL_Y)) + + def test_avg_pooling(): + res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='avg') + assert res[-1][-1][-1][-1] == 1.0000001 + assert res.shape == SMALL_Y - 5 + 1 + + def test_max_pooling(): + res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max') + assert res[-1][-1][-1][-1] == 1. 
+ assert res.shape == SMALL_Y - 5 + 1 + + def test_sum_pooling(): + res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='sum') + assert res[-1][-1][-1][-1] == 25 + assert res.shape == SMALL_Y - 5 + 1 + + def test_lp_pooling(): + res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=2) + assert res[-1][-1][-1][-1] == 5. + assert res.shape == SMALL_Y - 5 + 1 + + res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=1) + assert res[-1][-1][-1][-1] == 25. + assert res.shape == SMALL_Y - 5 + 1 + + test_avg_pooling() + test_max_pooling() + test_sum_pooling() + test_lp_pooling() + + if __name__ == '__main__': import nose nose.runmodule() From 1358a8f55e4395111235e4e04a2f0b66a9974dfc Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Mon, 12 Aug 2019 12:02:10 -0700 Subject: [PATCH 07/10] add layernorm --- tests/nightly/test_large_array.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index a432988fef36..400317aab44a 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -20,6 +20,7 @@ from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context from mxnet import gluon, nd from tests.python.unittest.common import with_seed +from tests.python.unittest.test_operator import check_layer_normalization # dimension constants MEDIUM_X = 10000 @@ -548,6 +549,15 @@ def test_lp_pooling(): test_lp_pooling() +def test_layer_norm(): + for forward_check_eps, backward_check_eps in zip([1E-2, 1E-5], [1E-2, 1E-5]): + check_layer_normalization(in_shape=(LARGE_X, SMALL_Y), + forward_check_eps=forward_check_eps, + backward_check_eps=backward_check_eps, + npy_grad_check=True, + finite_grad_check=True) + + if __name__ == '__main__': import nose nose.runmodule() From 8297d1fa2ec54601b3424492ac1aa546bba9e980 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 13 Aug 2019 09:53:36 -0700 Subject: [PATCH 08/10] add dropout, activation, batchnorm and update layernorm --- tests/nightly/test_large_array.py | 124 ++++++++++++++++++++++++++++-- 1 file changed, 117 insertions(+), 7 deletions(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 400317aab44a..9a862001c5ac 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -17,10 +17,10 @@ import numpy as np import mxnet as mx +import math from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context from mxnet import gluon, nd from tests.python.unittest.common import with_seed -from tests.python.unittest.test_operator import check_layer_normalization # dimension constants MEDIUM_X = 10000 @@ -550,12 +550,122 @@ def test_lp_pooling(): def test_layer_norm(): - for forward_check_eps, backward_check_eps in zip([1E-2, 1E-5], [1E-2, 1E-5]): - check_layer_normalization(in_shape=(LARGE_X, SMALL_Y), - forward_check_eps=forward_check_eps, - backward_check_eps=backward_check_eps, - npy_grad_check=True, - finite_grad_check=True) + dtype = np.float32 + forward_check_eps = 1E-3 + axis = 1 + eps = 1E-5 + in_shape = (LARGE_X, SMALL_Y) + ctx = mx.cpu() + + def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5): + if axis < 0: + axis += data.ndim + broadcast_shape = [1 for _ in range(data.ndim)] + broadcast_shape[axis] = data.shape[axis] + mean = data.mean(axis=axis, keepdims=True).astype(dtype) + var = data.var(axis=axis, keepdims=True).astype(dtype) + std = np.sqrt(var + dtype(eps)).astype(dtype) + out = np.reshape(gamma, 
broadcast_shape) * (data - mean) / std + \
+              np.reshape(beta, broadcast_shape)
+        return out
+    data = np.random.normal(0, 1, in_shape).astype(dtype)
+    gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    beta = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    data_s = mx.symbol.Variable('data')
+    gamma_s = mx.symbol.Variable('gamma')
+    beta_s = mx.symbol.Variable('beta')
+    out_s = mx.symbol.LayerNorm(data=data_s, gamma=gamma_s, beta=beta_s,
+                                axis=axis, eps=eps)
+    exe = out_s.simple_bind(ctx, data=in_shape)
+    exe.arg_dict['data'][:] = data
+    exe.arg_dict['gamma'][:] = gamma
+    exe.arg_dict['beta'][:] = beta
+    out_nd = exe.forward()[0]
+    out = npy_layer_norm(data, gamma, beta, axis, eps)
+    assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
+                        forward_check_eps)
+    # for forward_check_eps, backward_check_eps in zip([1E-2, 1E-5], [1E-2, 1E-5]):
+    #     check_layer_normalization(in_shape=(LARGE_X, SMALL_Y),
+    #                               forward_check_eps=forward_check_eps,
+    #                               backward_check_eps=backward_check_eps,
+    #                               npy_grad_check=True,
+    #                               finite_grad_check=True)
+
+
+# TODO: correctness of dropout
+# currently only test for dropout to work
+# since testing for correctness involves flakiness issue #14288
+def test_dropout():
+    shape = (10, 10)
+    x = mx.sym.var('data')
+    y = mx.sym.Dropout(x, p=1, cudnn_off=True)
+    exe = y.simple_bind(ctx=default_context(), data=shape)
+    exe.arg_arrays[0][:] = 1
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+
+def test_activation():
+    a = mx.nd.ones((LARGE_X, SMALL_Y))
+    test_x = -2
+    a[-1, -1] = test_x
+
+    # Hyperbolic tangent (tanh)
+    # y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
+    a = mx.nd.Activation(a, act_type="tanh")
+    tanh_x = (np.exp(-2)-np.exp(2))/(np.exp(-2)+np.exp(2))
+    assert a[-1][-1] == tanh_x
+
+    # Rectified Linear Unit (relu)
+    # y = max(x,0)
+    a = mx.nd.Activation(a, act_type="relu")
+    assert a[-1][-1] == 0
+
+    # Sigmoid
+    # y = 1/(1+exp(-x))
+    a = mx.nd.Activation(a, act_type="sigmoid")
+    sigmoid_x = 1/(1+math.exp(-test_x))
+    assert a[-1][-1] == sigmoid_x
+
+    # Soft Sign
+    # y = x/(1+abs(x))
+    a = mx.nd.Activation(a, act_type="softsign")
+    softsign_x = test_x/(1+abs(test_x))
+    assert a[-1][-1] == softsign_x
+
+
+# TODO: correctness of batchnorm
+# in future, we could test if mean, var of output
+# matches target output's mean, var
+def test_batchnorm():
+    # output_mean_var=True  # useful for correctness check
+    # epsilon = 0.0010000000474974513  # default
+    shape = (LARGE_X, SMALL_Y)
+    axis = 1  # default
+    expand_shape = [1] * len(shape)
+    expand_shape[axis] = shape[axis]
+
+    nch = shape[axis]
+    data = mx.nd.ones(shape=shape)
+    bn_gamma = mx.nd.random.uniform(shape=(nch,))
+    bn_beta = mx.nd.random.uniform(shape=(nch,))
+    bn_running_mean = mx.nd.zeros(nch)
+    bn_running_var = mx.nd.ones(nch)
+
+    output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
+                             bn_running_mean, bn_running_var)
+    output.wait_to_read()
+    # flaky
+    # data_mean = data.mean(axis=axis, exclude=True, keepdims=True)
+    # data_var = (data - data_mean).square().mean(axis=axis,
+    #                                             exclude=True,
+    #                                             keepdims=True)
+    # target_output = (data - data_mean) / \
+    #                 (data_var + epsilon).sqrt() * \
+    #                 bn_gamma.reshape(expand_shape) + \
+    #                 bn_beta.reshape(expand_shape)
+    # assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
+    #                     atol=1e-2, rtol=1e-2)


 if __name__ == '__main__':

From 2b2c9191e0fb145f5ae2261912014c2345cfbe00 Mon Sep 17 00:00:00 2001
From: ChaiBapchya
Date: Wed, 14 Aug 2019 09:12:46 -0700
Subject: [PATCH 09/10] address comments to remove some comments

---
 tests/nightly/test_large_array.py | 35 +------------------------------
 1 file changed, 1 insertion(+), 34 deletions(-)

diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 61ed4cd48637..f272da7a0676 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -447,12 +447,6 @@ def test_index_copy():
     assert x[-1][-1] == t[0][-1]


-# def test_rnn_and_dropouts():
-#     a = mx.nd.ones((LARGE_X, SMALL_Y))
-#     # test dropout ratio
-#     x = nx.sym.var('data')
-
-
 def testSoftmaxOutput():
     x = mx.sym.Variable('x')
     label = mx.sym.Variable('label')
@@ -478,6 +472,7 @@ def testSoftmaxOutput():
     assert np.isclose(grad_out - softmax_out, expected_grad_out).all()


+# TODO: correctness of prelu (currently flaky)
 def test_leaky_relu():
     a = -1*mx.nd.ones((LARGE_X, SMALL_Y))

@@ -495,13 +490,6 @@ def test_selu():
         res = mx.nd.LeakyReLU(a, act_type="selu")
         assert res[-1][-1].asnumpy() == (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1))

-    # def test_prelu():
-    #     res = mx.nd.LeakyReLU(a, act_type="prelu", gamma=mx.nd.array([0.3]))
-    #     assert res[-1][-1].asnumpy() == 0.3 * a[-1][-1].asnumpy()
-    # fails with large tensor shape
-    # all values from [0][0] till [14100654][3] have correct -0.3
-    # all values from [14100654][4] till [-1][-1] have 0.
-
     def test_rrelu():
         lower = 0.125
         upper = 0.333999991
@@ -511,7 +499,6 @@ def test_rrelu():
     test_leaky()
     test_elu()
     test_selu()
-    # test_prelu()
     test_rrelu()


@@ -583,13 +570,6 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
     out = npy_layer_norm(data, gamma, beta, axis, eps)
     assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
                         forward_check_eps)
-    # for forward_check_eps, backward_check_eps in zip([1E-2, 1E-5], [1E-2, 1E-5]):
-    #     check_layer_normalization(in_shape=(LARGE_X, SMALL_Y),
-    #                               forward_check_eps=forward_check_eps,
-    #                               backward_check_eps=backward_check_eps,
-    #                               npy_grad_check=True,
-    #                               finite_grad_check=True)
-

 # TODO: correctness of dropout
 # currently only test for dropout to work
@@ -637,8 +617,6 @@ def test_activation():
 # in future, we could test if mean, var of output
 # matches target output's mean, var
 def test_batchnorm():
-    # output_mean_var=True  # useful for correctness check
-    # epsilon = 0.0010000000474974513  # default
     shape = (LARGE_X, SMALL_Y)
     axis = 1  # default
     expand_shape = [1] * len(shape)
@@ -654,17 +632,6 @@ def test_batchnorm():
     output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
                              bn_running_mean, bn_running_var)
     output.wait_to_read()
-    # flaky
-    # data_mean = data.mean(axis=axis, exclude=True, keepdims=True)
-    # data_var = (data - data_mean).square().mean(axis=axis,
-    #                                             exclude=True,
-    #                                             keepdims=True)
-    # target_output = (data - data_mean) / \
-    #                 (data_var + epsilon).sqrt() * \
-    #                 bn_gamma.reshape(expand_shape) + \
-    #                 bn_beta.reshape(expand_shape)
-    # assert_almost_equal(output.asnumpy(), target_output.asnumpy(),
-    #                     atol=1e-2, rtol=1e-2)


 if __name__ == '__main__':

From b6c418cb8e95813600f11849576116facc2e1456 Mon Sep 17 00:00:00 2001
From: ChaiBapchya
Date: Wed, 14 Aug 2019 09:18:49 -0700
Subject: [PATCH 10/10] handling imports

---
 tests/nightly/test_large_array.py  | 3 ++-
 tests/nightly/test_large_vector.py | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 54fbe72f3658..585748ee59b1 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -15,9 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.

+import math
 import numpy as np
 import mxnet as mx
-import math
+
 from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context
 from mxnet import gluon, nd
 from tests.python.unittest.common import with_seed

diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py
index 8c030f5bc20e..3a66500957e0 100644
--- a/tests/nightly/test_large_vector.py
+++ b/tests/nightly/test_large_vector.py
@@ -17,6 +17,7 @@

 import numpy as np
 import mxnet as mx
+
 from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
 from mxnet import gluon, nd
 from tests.python.unittest.common import with_seed
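
For reference, and not part of the patch series above: a minimal small-shape sketch of the SequenceMask/SequenceLast behaviour that PATCH 01/10 exercises at large scale, assuming an MXNet build with the NDArray API available. The shapes and values here are illustrative only.

import mxnet as mx

# (max_sequence_length, batch_size, other_feature_dims) layout, sequence axis 0
a = mx.nd.arange(0, 3 * 2 * 4).reshape((3, 2, 4))

# keep only the first step of each batch element, fill the rest with -1
masked = mx.nd.SequenceMask(a, sequence_length=mx.nd.array([1, 1]),
                            use_sequence_length=True, value=-1)
print(masked[1:])  # every step after the first is -1

# pick step 2 of batch element 0 and step 3 of batch element 1
last = mx.nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
                          use_sequence_length=True)
print(last)  # rows equal a[1][0] and a[2][1]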