Skip to content

Commit

Permalink
Add Large Tensor Support for Sequence, NN Ops (apache#15807)
Browse files Browse the repository at this point in the history
* sequence_last, sequence_reverse, sequence_mask

* working softmax_cross_entropy

* fix linting, add index_copy

* add softmax output

* add leaky relu

* add pooling

* add layernorm

* add dropout, activation, batchnorm and update layernorm

* address comments to remove some comments

* handling imports
  • Loading branch information
ChaiBapchya authored and Rohit Kumar Srivastava committed Sep 25, 2019
1 parent 29d0592 commit 5b5af38
Show file tree
Hide file tree
Showing 2 changed files with 293 additions and 8 deletions.
300 changes: 292 additions & 8 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import math
import numpy as np
import mxnet as mx
from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed

Expand Down Expand Up @@ -299,9 +301,11 @@ def test_pick():
def test_depthtospace():
def numpy_depth_to_space(x, blocksize):
b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w])
tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h,
w])
tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize])
y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize,
w * blocksize])
return y

shape_inp = (LARGE_X, 8, 4, 2)
Expand All @@ -315,9 +319,11 @@ def numpy_depth_to_space(x, blocksize):
def test_spacetodepth():
def numpy_space_to_depth(x, blocksize):
b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize])
tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize,
blocksize])
tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize])
y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize,
w // blocksize])
return y

shape_inp = (LARGE_X, 2, 8, 4)
Expand All @@ -327,6 +333,7 @@ def numpy_space_to_depth(x, blocksize):
output = mx.nd.space_to_depth(data, 2)
assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)


@with_seed()
def test_diag():
a_np = np.random.random((LARGE_X, SMALL_Y)).astype(np.float32)
Expand Down Expand Up @@ -358,7 +365,8 @@ def test_ravel_multi_index():
x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
indices_2d = [[x1, x2, x3], [y1, y2, y3]]
idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64),
shape=(LARGE_X, SMALL_Y))
idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, SMALL_Y))
assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3

Expand All @@ -370,7 +378,8 @@ def test_unravel_index():
x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
original_2d_indices = [[x1, x2, x3], [y1, y2, y3]]
idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, SMALL_Y))
indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64),
shape=(LARGE_X, SMALL_Y))
assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()


Expand Down Expand Up @@ -427,13 +436,288 @@ def test_topk():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
k = nd.topk(b, k=10, axis=0, dtype=np.int64)
assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y
ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False)
ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both",
is_ascend=False)
assert np.all(ind == val)
b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value")
assert l.sum() == np.sum(np.arange(0, SMALL_Y))


def test_sequence_mask():
# Sequence Mask input [max_sequence_length, batch_size, other_feature_dims]
# test with input batch_size = 2
a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)

# test as identity operator
b = nd.SequenceMask(a)
assert b[-1][0][1] == a[-1][0][1]
assert b.shape == a.shape

# test with default mask
b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
use_sequence_length=True)
assert b[0][1][-1] == a[0][1][-1] # first sequence of each batch kept
assert b[-1][-1][-1] != a[-1][-1][-1] # rest sequences masked
assert b[-1][-1][-1] == 0

# test with mask value
b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
use_sequence_length=True, value=-1)
assert b[-1][-1][-1] == -1


def test_sequence_reverse():
a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
# test as reverse operator
b = nd.SequenceReverse(a)
assert b[-1][0][0] == a[0][0][0]
assert b.shape == a.shape

# test with sequence length
b = nd.SequenceReverse(a, sequence_length=[2, 3])
assert b[1][0][0] == a[0][0][0] # check if reversed
assert b[-1][0][0] == a[-1][0][0] # check if intact
assert b.shape == a.shape


def test_sequence_last():
a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)

# test if returns last sequence
b = nd.SequenceLast(a)
assert_almost_equal(b, a[-1]) # only checks for (2,SMALL_Y) tensor
assert b.shape == (2, SMALL_Y)

# test with sequence length
# parameter sequence_length - NDArray with shape (batch_size)
# (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
use_sequence_length=True)
# check if it takes 2nd sequence from the first batch
assert b[0][-1] == a[1][0][-1]


def test_softmax_cross_entropy():
# dtype of input data, mxnet cross entropy set explicitly to float64
# numpy implicitly takes care of double precision
batch_size = SMALL_Y
num_labels = LARGE_X
input_data = mx.nd.ones((batch_size, num_labels), dtype="float64")
input_label = mx.nd.zeros((batch_size,), dtype="float64")

true_softmax = np.full((batch_size, num_labels), (1 / num_labels))
# use 1/batch_size when softmax axis=0
# here 1/num_labels since softmax_cross_entropy uses default axis
# by default axis=1
np_one_hot_label = np.zeros((batch_size, num_labels))
np_one_hot_label[:, 0] = 1

true_softmax_cross_entropy = np.sum(-np.log(true_softmax) *
np_one_hot_label)
mx_softmax_cross_entropy = mx.nd.softmax_cross_entropy(input_data,
input_label,
dtype="float64")
assert_almost_equal(mx_softmax_cross_entropy.asnumpy(),
true_softmax_cross_entropy, rtol=1e-3, atol=1e-5)


def test_index_copy():
x = mx.nd.zeros((LARGE_X, SMALL_Y))
t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y))
index = mx.nd.array([LARGE_X - 1])

x = mx.nd.contrib.index_copy(x, index, t)
assert x[-1][-1] == t[0][-1]


def testSoftmaxOutput():
x = mx.sym.Variable('x')
label = mx.sym.Variable('label')
x_nd = mx.nd.ones((LARGE_X, SMALL_Y))
grad_x = mx.nd.zeros((LARGE_X, SMALL_Y))
label_nd = mx.nd.ones((LARGE_X))

sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0,
use_ignore=False)
ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
args_grad={'x': grad_x})

ex.forward(is_train=True)
softmax_out = ex.outputs[0][0].asnumpy()
expected_softmax_out = (1/SMALL_Y)*mx.nd.ones((SMALL_Y)).asnumpy()
assert np.isclose(softmax_out, expected_softmax_out).all()

ex.backward(is_train=True)
grad_out = ex.grad_arrays[0][0].asnumpy()
k = int(label_nd[0].asscalar())
expected_grad_out = np.zeros((SMALL_Y,))
expected_grad_out[k] = -1
assert np.isclose(grad_out - softmax_out, expected_grad_out).all()


# TODO: correctness of prelu (currently flaky)
def test_leaky_relu():
a = -1*mx.nd.ones((LARGE_X, SMALL_Y))

def test_leaky():
res = mx.nd.LeakyReLU(a, act_type="leaky", slope=0.3)
assert res[-1][-1].asnumpy() == 0.3*a[-1][-1].asnumpy()

def test_elu():
res = mx.nd.LeakyReLU(a, act_type="elu", slope=0.3)
assert res[-1][-1].asnumpy() == 0.3*(np.exp(a[-1][-1].asnumpy())-1)

def test_selu():
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
res = mx.nd.LeakyReLU(a, act_type="selu")
assert res[-1][-1].asnumpy() == (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1))

def test_rrelu():
lower = 0.125
upper = 0.333999991
res = mx.nd.LeakyReLU(a, act_type="rrelu")
assert res[-1][-1].asnumpy() == (lower + upper) / 2 * a[-1][-1].asnumpy()

test_leaky()
test_elu()
test_selu()
test_rrelu()


def test_pooling():
a = mx.nd.ones((MEDIUM_X, MEDIUM_X, SMALL_Y, SMALL_Y))

def test_avg_pooling():
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='avg')
assert res[-1][-1][-1][-1] == 1.0000001
assert res.shape == SMALL_Y - 5 + 1

def test_max_pooling():
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max')
assert res[-1][-1][-1][-1] == 1.
assert res.shape == SMALL_Y - 5 + 1

def test_sum_pooling():
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='sum')
assert res[-1][-1][-1][-1] == 25
assert res.shape == SMALL_Y - 5 + 1

def test_lp_pooling():
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=2)
assert res[-1][-1][-1][-1] == 5.
assert res.shape == SMALL_Y - 5 + 1

res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=1)
assert res[-1][-1][-1][-1] == 25.
assert res.shape == SMALL_Y - 5 + 1

test_avg_pooling()
test_max_pooling()
test_sum_pooling()
test_lp_pooling()


def test_layer_norm():
dtype = np.float32
forward_check_eps = 1E-3
axis = 1
eps = 1E-5
in_shape = (LARGE_X, SMALL_Y)
ctx = mx.cpu()

def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
if axis < 0:
axis += data.ndim
broadcast_shape = [1 for _ in range(data.ndim)]
broadcast_shape[axis] = data.shape[axis]
mean = data.mean(axis=axis, keepdims=True).astype(dtype)
var = data.var(axis=axis, keepdims=True).astype(dtype)
std = np.sqrt(var + dtype(eps)).astype(dtype)
out = np.reshape(gamma, broadcast_shape) * (data - mean) / std + \
np.reshape(beta, broadcast_shape)
return out
data = np.random.normal(0, 1, in_shape).astype(dtype)
gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
beta = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
data_s = mx.symbol.Variable('data')
gamma_s = mx.symbol.Variable('gamma')
beta_s = mx.symbol.Variable('beta')
out_s = mx.symbol.LayerNorm(data=data_s, gamma=gamma_s, beta=beta_s,
axis=axis, eps=eps)
exe = out_s.simple_bind(ctx, data=in_shape)
exe.arg_dict['data'][:] = data
exe.arg_dict['gamma'][:] = gamma
exe.arg_dict['beta'][:] = beta
out_nd = exe.forward()[0]
out = npy_layer_norm(data, gamma, beta, axis, eps)
assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
forward_check_eps)

# TODO: correctness of dropout
# currently only test for dropout to work
# since testing for correctness involves flakiness issue #14288
def test_dropout():
shape = (10, 10)
x = mx.sym.var('data')
y = mx.sym.Dropout(x, p=1, cudnn_off=True)
exe = y.simple_bind(ctx=default_context(), data=shape)
exe.arg_arrays[0][:] = 1
out = exe.forward(is_train=True)
out[0].wait_to_read()


def test_activation():
a = mx.nd.ones((LARGE_X, SMALL_Y))
test_x = -2
a[-1, -1] = test_x

# Hyperbolic tangent (tanh)
# y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
a = mx.nd.Activation(a, act_type="tanh")
tanh_x = (np.exp(-2)-np.exp(2))/(np.exp(-2)+np.exp(2))
assert a[-1][-1] == tanh_x

# Recitified Linear Unit (relu)
# y = max(x,0)
a = mx.nd.Activation(a, act_type="relu")
assert a[-1][-1] == 0

# Sigmoid
# y = x/(1+abs(x))
a = mx.nd.Activation(a, act_type="sigmoid")
sigmoid_x = 1/(1+math.exp(-test_x))
assert a[-1][-1] == sigmoid_x

# Soft Sign
# y = 1/(1+exp(-x))
a = mx.nd.Activation(a, act_type="softsign")
softsign_x = test_x/(1+abs(test_x))
assert a[-1][-1] == softsign_x


# TODO: correctness of batchnorm
# in future, we could test if mean, var of output
# matches target output's mean, var
def test_batchnorm():
shape = (LARGE_X, SMALL_Y)
axis = 1 # default
expand_shape = [1] * len(shape)
expand_shape[axis] = shape[axis]

nch = shape[axis]
data = mx.nd.ones(shape=shape)
bn_gamma = mx.nd.random.uniform(shape=(nch,))
bn_beta = mx.nd.random.uniform(shape=(nch,))
bn_running_mean = mx.nd.zeros(nch)
bn_running_var = mx.nd.ones(nch)

output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var)
output.wait_to_read()


def test_add():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
b = nd.ones(shape=(LARGE_X, SMALL_Y))
Expand Down
1 change: 1 addition & 0 deletions tests/nightly/test_large_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import mxnet as mx

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed
Expand Down

0 comments on commit 5b5af38

Please sign in to comment.