From 0bc6f5b9b907649c68d239b40f238e7f597a2bc1 Mon Sep 17 00:00:00 2001
From: Chaitanya Prakash Bapat
Date: Wed, 14 Aug 2019 13:10:58 -0700
Subject: [PATCH] Add Large Tensor Support for Sequence, NN Ops (#15807)

* sequence_last, sequence_reverse, sequence_mask

* working softmax_cross_entropy

* fix linting, add index_copy

* add softmax output

* add leaky relu

* add pooling

* add layernorm

* add dropout, activation, batchnorm and update layernorm

* address comments to remove some comments

* handling imports
---
 tests/nightly/test_large_array.py  | 320 ++++++++++++++++++++++++++++-
 tests/nightly/test_large_vector.py |   1 +
 2 files changed, 313 insertions(+), 8 deletions(-)

diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index a0dfea6ad4bf..585748ee59b1 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -15,9 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import math
 import numpy as np
 import mxnet as mx
-from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
+
+from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context
 from mxnet import gluon, nd
 from tests.python.unittest.common import with_seed
 
@@ -299,9 +301,11 @@ def test_pick():
 def test_depthtospace():
     def numpy_depth_to_space(x, blocksize):
         b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
-        tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w])
+        tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h,
+                         w])
         tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
-        y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+        y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize,
+                       w * blocksize])
         return y
 
     shape_inp = (LARGE_X, 8, 4, 2)
@@ -315,9 +319,11 @@ def numpy_depth_to_space(x, blocksize):
 def test_spacetodepth():
     def numpy_space_to_depth(x, blocksize):
         b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
-        tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize])
+        tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize,
+                         blocksize])
         tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
-        y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize])
+        y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize,
+                       w // blocksize])
         return y
 
     shape_inp = (LARGE_X, 2, 8, 4)
@@ -327,6 +333,7 @@ def numpy_space_to_depth(x, blocksize):
     output = mx.nd.space_to_depth(data, 2)
     assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)
 
+
 @with_seed()
 def test_diag():
     a_np = np.random.random((LARGE_X, SMALL_Y)).astype(np.float32)
@@ -358,7 +365,8 @@ def test_ravel_multi_index():
     x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
     x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
     indices_2d = [[x1, x2, x3], [y1, y2, y3]]
-    idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
+    idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64),
+                                  shape=(LARGE_X, SMALL_Y))
     idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, SMALL_Y))
     assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3
 
@@ -370,7 +378,8 @@ def test_unravel_index():
     x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
     original_2d_indices = [[x1, x2, x3], [y1, y2, y3]]
     idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, SMALL_Y))
-    indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
+    indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64),
+                                     shape=(LARGE_X, SMALL_Y))
     assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()
 
 
@@ -427,13 +436,308 @@ def test_topk():
     b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
     k = nd.topk(b, k=10, axis=0, dtype=np.int64)
     assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y
-    ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False)
+    ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both",
+                          is_ascend=False)
     assert np.all(ind == val)
     b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
     l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value")
     assert l.sum() == np.sum(np.arange(0, SMALL_Y))
 
 
+def test_sequence_mask():
+    # Sequence Mask input [max_sequence_length, batch_size, other_feature_dims]
+    # test with input batch_size = 2
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+
+    # test as identity operator
+    b = nd.SequenceMask(a)
+    assert b[-1][0][1] == a[-1][0][1]
+    assert b.shape == a.shape
+
+    # test with default mask
+    b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
+                        use_sequence_length=True)
+    assert b[0][1][-1] == a[0][1][-1]  # first sequence of each batch kept
+    assert b[-1][-1][-1] != a[-1][-1][-1]  # rest sequences masked
+    assert b[-1][-1][-1] == 0
+
+    # test with mask value
+    b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
+                        use_sequence_length=True, value=-1)
+    assert b[-1][-1][-1] == -1
+
+
+def test_sequence_reverse():
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+    # test as reverse operator
+    b = nd.SequenceReverse(a)
+    assert b[-1][0][0] == a[0][0][0]
+    assert b.shape == a.shape
+
+    # test with sequence length
+    # sequence_length is passed as an NDArray together with
+    # use_sequence_length=True, as in the SequenceMask/SequenceLast tests
+    b = nd.SequenceReverse(a, sequence_length=nd.array([2, 3]),
+                           use_sequence_length=True)
+    assert b[1][0][0] == a[0][0][0]  # check if reversed
+    assert b[-1][0][0] == a[-1][0][0]  # check if intact
+    assert b.shape == a.shape
+
+
+def test_sequence_last():
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+
+    # test if returns last sequence
+    b = nd.SequenceLast(a)
+    # only checks for (2,SMALL_Y) tensor
+    assert_almost_equal(b.asnumpy(), a[-1].asnumpy())
+    assert b.shape == (2, SMALL_Y)
+
+    # test with sequence length
+    # parameter sequence_length - NDArray with shape (batch_size)
+    # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
+    b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
+                        use_sequence_length=True)
+    # check if it takes 2nd sequence from the first batch
+    assert b[0][-1] == a[1][0][-1]
+
+
+def test_softmax_cross_entropy():
+    # dtype of input data, mxnet cross entropy set explicitly to float64
+    # numpy implicitly takes care of double precision
+    batch_size = SMALL_Y
+    num_labels = LARGE_X
+    input_data = mx.nd.ones((batch_size, num_labels), dtype="float64")
+    input_label = mx.nd.zeros((batch_size,), dtype="float64")
+
+    true_softmax = np.full((batch_size, num_labels), (1 / num_labels))
+    # use 1/batch_size when softmax axis=0
+    # here 1/num_labels since softmax_cross_entropy uses default axis
+    # by default axis=1
+    np_one_hot_label = np.zeros((batch_size, num_labels))
+    np_one_hot_label[:, 0] = 1
+
+    true_softmax_cross_entropy = np.sum(-np.log(true_softmax) *
+                                        np_one_hot_label)
+    mx_softmax_cross_entropy = mx.nd.softmax_cross_entropy(input_data,
+                                                           input_label,
+                                                           dtype="float64")
+    assert_almost_equal(mx_softmax_cross_entropy.asnumpy(),
+                        true_softmax_cross_entropy, rtol=1e-3, atol=1e-5)
+
+
+def test_index_copy():
+    x = mx.nd.zeros((LARGE_X, SMALL_Y))
+    t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y))
+    # explicit int64 index (as in the ravel/unravel tests): a default-dtype
+    # array cannot be relied on to hold a row index of LARGE_X - 1 exactly
+    index = mx.nd.array([LARGE_X - 1], dtype=np.int64)
+
+    x = mx.nd.contrib.index_copy(x, index, t)
+    assert x[-1][-1] == t[0][-1]
+
+
+def testSoftmaxOutput():
+    x = mx.sym.Variable('x')
+    label = mx.sym.Variable('label')
+    x_nd = mx.nd.ones((LARGE_X, SMALL_Y))
+    grad_x = mx.nd.zeros((LARGE_X, SMALL_Y))
+    label_nd = mx.nd.ones((LARGE_X))
+
+    sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0,
+                               use_ignore=False)
+    ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
+                  args_grad={'x': grad_x})
+
+    ex.forward(is_train=True)
+    softmax_out = ex.outputs[0][0].asnumpy()
+    expected_softmax_out = (1/SMALL_Y)*mx.nd.ones((SMALL_Y)).asnumpy()
+    assert np.isclose(softmax_out, expected_softmax_out).all()
+
+    ex.backward(is_train=True)
+    grad_out = ex.grad_arrays[0][0].asnumpy()
+    k = int(label_nd[0].asscalar())
+    expected_grad_out = np.zeros((SMALL_Y,))
+    expected_grad_out[k] = -1
+    assert np.isclose(grad_out - softmax_out, expected_grad_out).all()
+
+
+# TODO: correctness of prelu (currently flaky)
+def test_leaky_relu():
+    a = -1*mx.nd.ones((LARGE_X, SMALL_Y))
+
+    # operator output is compared with a tolerance instead of exact float
+    # equality against the NumPy reference value
+    def test_leaky():
+        res = mx.nd.LeakyReLU(a, act_type="leaky", slope=0.3)
+        assert_almost_equal(res[-1][-1].asnumpy(), 0.3*a[-1][-1].asnumpy(),
+                            atol=1e-3, rtol=1e-3)
+
+    def test_elu():
+        res = mx.nd.LeakyReLU(a, act_type="elu", slope=0.3)
+        assert_almost_equal(res[-1][-1].asnumpy(),
+                            0.3*(np.exp(a[-1][-1].asnumpy())-1),
+                            atol=1e-3, rtol=1e-3)
+
+    def test_selu():
+        lam = 1.0507009873554804934193349852946
+        alpha = 1.6732632423543772848170429916717
+        res = mx.nd.LeakyReLU(a, act_type="selu")
+        assert_almost_equal(res[-1][-1].asnumpy(),
+                            lam * alpha * (np.exp(a[-1][-1].asnumpy())-1),
+                            atol=1e-3, rtol=1e-3)
+
+    def test_rrelu():
+        lower = 0.125
+        upper = 0.333999991
+        res = mx.nd.LeakyReLU(a, act_type="rrelu")
+        assert_almost_equal(res[-1][-1].asnumpy(),
+                            (lower + upper) / 2 * a[-1][-1].asnumpy(),
+                            atol=1e-3, rtol=1e-3)
+
+    test_leaky()
+    test_elu()
+    test_selu()
+    test_rrelu()
+
+
+def test_pooling():
+    a = mx.nd.ones((MEDIUM_X, MEDIUM_X, SMALL_Y, SMALL_Y))
+
+    # res.shape is a tuple, so the pooled spatial extent (its last dimension)
+    # is what gets compared with the expected output size
+    def test_avg_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='avg')
+        assert res[-1][-1][-1][-1] == 1.0000001
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+    def test_max_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max')
+        assert res[-1][-1][-1][-1] == 1.
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+    def test_sum_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='sum')
+        assert res[-1][-1][-1][-1] == 25
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+    def test_lp_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=2)
+        assert res[-1][-1][-1][-1] == 5.
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=1)
+        assert res[-1][-1][-1][-1] == 25.
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+    test_avg_pooling()
+    test_max_pooling()
+    test_sum_pooling()
+    test_lp_pooling()
+
+
+def test_layer_norm():
+    dtype = np.float32
+    forward_check_eps = 1E-3
+    axis = 1
+    eps = 1E-5
+    in_shape = (LARGE_X, SMALL_Y)
+    ctx = mx.cpu()
+
+    def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
+        if axis < 0:
+            axis += data.ndim
+        broadcast_shape = [1 for _ in range(data.ndim)]
+        broadcast_shape[axis] = data.shape[axis]
+        mean = data.mean(axis=axis, keepdims=True).astype(dtype)
+        var = data.var(axis=axis, keepdims=True).astype(dtype)
+        std = np.sqrt(var + dtype(eps)).astype(dtype)
+        out = np.reshape(gamma, broadcast_shape) * (data - mean) / std + \
+              np.reshape(beta, broadcast_shape)
+        return out
+    data = np.random.normal(0, 1, in_shape).astype(dtype)
+    gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    beta = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    data_s = mx.symbol.Variable('data')
+    gamma_s = mx.symbol.Variable('gamma')
+    beta_s = mx.symbol.Variable('beta')
+    out_s = mx.symbol.LayerNorm(data=data_s, gamma=gamma_s, beta=beta_s,
+                                axis=axis, eps=eps)
+    exe = out_s.simple_bind(ctx, data=in_shape)
+    exe.arg_dict['data'][:] = data
+    exe.arg_dict['gamma'][:] = gamma
+    exe.arg_dict['beta'][:] = beta
+    out_nd = exe.forward()[0]
+    out = npy_layer_norm(data, gamma, beta, axis, eps)
+    assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
+                        forward_check_eps)
+
+# TODO: correctness of dropout
+# currently only test for dropout to work
+# since testing for correctness involves flakiness issue #14288
+def test_dropout():
+    shape = (10, 10)
+    x = mx.sym.var('data')
+    y = mx.sym.Dropout(x, p=1, cudnn_off=True)
+    exe = y.simple_bind(ctx=default_context(), data=shape)
+    exe.arg_arrays[0][:] = 1
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+
+def test_activation():
+    a = mx.nd.ones((LARGE_X, SMALL_Y))
+    test_x = -2
+    a[-1, -1] = test_x
+
+    # every activation is applied to the untouched input `a`, so each check
+    # below evaluates the function at test_x rather than at a previous output
+
+    # Hyperbolic tangent (tanh)
+    # y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
+    res = mx.nd.Activation(a, act_type="tanh")
+    tanh_x = (np.exp(test_x)-np.exp(-test_x))/(np.exp(test_x)+np.exp(-test_x))
+    assert np.isclose(res[-1][-1].asscalar(), tanh_x)
+
+    # Rectified Linear Unit (relu)
+    # y = max(x,0)
+    res = mx.nd.Activation(a, act_type="relu")
+    assert res[-1][-1] == 0
+
+    # Sigmoid
+    # y = 1/(1+exp(-x))
+    res = mx.nd.Activation(a, act_type="sigmoid")
+    sigmoid_x = 1/(1+math.exp(-test_x))
+    assert np.isclose(res[-1][-1].asscalar(), sigmoid_x)
+
+    # Soft Sign
+    # y = x/(1+abs(x))
+    res = mx.nd.Activation(a, act_type="softsign")
+    softsign_x = test_x/(1+abs(test_x))
+    assert np.isclose(res[-1][-1].asscalar(), softsign_x)
+
+
+# TODO: correctness of batchnorm
+# in future, we could test if mean, var of output
+# matches target output's mean, var
+def test_batchnorm():
+    shape = (LARGE_X, SMALL_Y)
+    axis = 1  # default
+    expand_shape = [1] * len(shape)
+    expand_shape[axis] = shape[axis]
+
+    nch = shape[axis]
+    data = mx.nd.ones(shape=shape)
+    bn_gamma = mx.nd.random.uniform(shape=(nch,))
+    bn_beta = mx.nd.random.uniform(shape=(nch,))
+    bn_running_mean = mx.nd.zeros(nch)
+    bn_running_var = mx.nd.ones(nch)
+
+    output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
+                             bn_running_mean, bn_running_var)
+    output.wait_to_read()
+
+
 def test_add():
     a = nd.ones(shape=(LARGE_X, SMALL_Y))
     b = nd.ones(shape=(LARGE_X, SMALL_Y))
diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py
index 8c030f5bc20e..3a66500957e0 100644
--- a/tests/nightly/test_large_vector.py
+++ b/tests/nightly/test_large_vector.py
@@ -17,6 +17,7 @@
 
 import numpy as np
 import mxnet as mx
+
 from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
 from mxnet import gluon, nd
 from tests.python.unittest.common import with_seed