Skip to content

Commit

Permalink
Add Large tensor vector test cases (apache#15941)
Browse files Browse the repository at this point in the history
* add random ops

* add shuffle to test large array

* shape evaluation after value check

* add log, exponent, power ops

* fix sequence reverse issue in test_large_array and add sequence ops to test_large_vector

* add binary arithmetic

* fix lint, minor mistakes in large_array; add nn op to tensor

* Trigger notification coz of test_operator.test_laop_6 error

* Trigger notification coz of test_operator.test_laop_6 error

* Trigger notification bcoz R failures

* address comments

* normal distribution assert statement fix; randint dtype check

* correct layernorm and shuffle

* layer norm numpy flaky hence removed, dropout shape fix

* comment not working ops

* fix multi

* Trigger notification

* fix seq reverse, uncomment seq mask as it works

* index fix and uncomment test

* index fix

* seq_reverse index fix

* uncomment seq reverse test and handle static typecasts

* removing commented ops

* resolve merge conflict

* teardown, lint, remove redundant functions

* fix shape assertions and randint low,high

* remove waits, add teardown to large_array, change randint assert in large array
  • Loading branch information
ChaiBapchya authored and gyshi committed Sep 7, 2019
1 parent 0f1f10f commit 3387c81
Show file tree
Hide file tree
Showing 4 changed files with 388 additions and 49 deletions.
24 changes: 12 additions & 12 deletions src/operator/sequence_last-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,24 @@ struct SequenceLastParam : public dmlc::Parameter<SequenceLastParam> {
template <int req>
struct SequenceLastKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
const IType *idx, int offset1, int offset2,
MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in,
const IType *idx, index_t offset1, index_t offset2,
mshadow::Shape<2> oshape) {
const auto opos = mxnet_op::unravel(i, oshape);
const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
const int ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
const index_t seqpos = static_cast<index_t>(idx[opos[0]]) - 1;
const index_t ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
KERNEL_ASSIGN(out[i], req, in[ipos]);
}
};

struct SequenceLastGradKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad,
const IType *idx, int offset1, int offset2,
MSHADOW_XINLINE static void Map(index_t i, DType *in_grad, const DType *out_grad,
const IType *idx, index_t offset1, index_t offset2,
mshadow::Shape<2> oshape) {
const auto opos = mxnet_op::unravel(i, oshape);
const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
const int ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
const index_t seqpos = static_cast<index_t>(idx[opos[0]]) - 1;
const index_t ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
in_grad[ipos] += out_grad[i];
}
};
Expand All @@ -103,8 +103,8 @@ class SequenceLastOp : public Operator {
int axis = param_.axis;
int out_size = out.size(0) * out.size(1);
int max_seq_len = data.size(axis);
int offset1 = axis ? out.size(1) : out_size;
int offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1);
index_t offset1 = axis ? out.size(1) : out_size;
index_t offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1);

MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
mxnet_op::Kernel<SequenceLastKernel<req_type>, xpu>::Launch(
Expand All @@ -126,8 +126,8 @@ class SequenceLastOp : public Operator {
int out_size = batch * rest;

int max_seq_len = in_grad.size(axis);
int offset1 = axis ? rest : out_size;
int offset2 = axis ? (max_seq_len * rest) : rest;
index_t offset1 = axis ? rest : out_size;
index_t offset2 = axis ? (max_seq_len * rest) : rest;

mxnet_op::Kernel<SequenceLastGradKernel, xpu>::Launch(
s, out_size, in_grad.dptr_, out_grad.dptr_, indices.dptr_, offset1,
Expand Down
14 changes: 7 additions & 7 deletions src/operator/sequence_reverse-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,30 +67,30 @@ struct SequenceReverseParam : public dmlc::Parameter<SequenceReverseParam> {
template <OpReqType req>
struct ReverseKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(const int i, DType *const out_data,
MSHADOW_XINLINE static void Map(const index_t i, DType *const out_data,
const DType *const in_data,
const index_t max_seq_len,
const index_t batch_size,
const index_t other_dim, const index_t numel,
const IType *const indices) {
const index_t batch = i / (max_seq_len * other_dim);
const int id = (i / other_dim) % max_seq_len;
const index_t id = (i / other_dim) % max_seq_len;
const index_t j = i % other_dim;
const index_t num_seq =
indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
const index_t padded_periods = max_seq_len - num_seq;
// padded part
if (padded_periods > 0 && id < static_cast<int>(padded_periods)) {
const int padded_in_offset =
if (padded_periods > 0 && id < padded_periods) {
const index_t padded_in_offset =
(id + num_seq) * batch_size * other_dim + batch * other_dim;

KERNEL_ASSIGN(out_data[padded_in_offset + j], req,
in_data[padded_in_offset + j]);
}
// unpadded part
if (id < static_cast<int>(num_seq)) {
const int in_offset = id * batch_size * other_dim + batch * other_dim;
const int out_offset =
if (id < num_seq) {
const index_t in_offset = id * batch_size * other_dim + batch * other_dim;
const index_t out_offset =
numel - (id + 1 + padded_periods) * batch_size * other_dim +
batch * other_dim;

Expand Down
67 changes: 40 additions & 27 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed
from tests.python.unittest.common import with_seed, teardown

# dimension constants
MEDIUM_X = 10000
Expand Down Expand Up @@ -56,10 +56,8 @@ def test_ndarray_ones():
def test_ndarray_convert():
a = nd.zeros(shape=(LARGE_X, SMALL_Y))
b = a.astype(np.int32)
b.wait_to_read()
assert b.dtype == np.int32
b = a.tostype('row_sparse')
b.wait_to_read()
assert isinstance(b, mx.nd.sparse.RowSparseNDArray)


Expand All @@ -79,15 +77,16 @@ def test_ndarray_random_randint():
a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64)
low = mx.nd.array([low_large_value], dtype='int64')
high = mx.nd.array([high_large_value], dtype='int64')
assert a.__gt__(low) and a.__lt__(high)
assert a >= low and a < high
assert a[-1][0].dtype == np.int64


@with_seed()
def test_ndarray_random_exponential():
scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.exponential(scale=scale_array, shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
Expand All @@ -96,25 +95,25 @@ def test_ndarray_random_gamma():
beta_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.gamma(alpha=alpha_array, beta=beta_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
def test_ndarray_random_multinomial():
# test 1 shape dimension
probs = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
a = nd.random.multinomial(probs)
assert a.shape == (LARGE_X,)
assert a[-1] >= 0
assert a.shape == (LARGE_X,)
# test for NDArray multi-dimension shape
a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_X, SMALL_Y)
assert a[-1][0][0] >= 0
assert a.shape == (LARGE_X, SMALL_X, SMALL_Y)
# test log_likelihood output shape
a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y), get_prob=True)
assert a[0].shape == (LARGE_X, SMALL_X, SMALL_Y) and a[0].shape == a[1].shape
assert a[-1][0][0] >= 0
assert a[0].shape == (LARGE_X, SMALL_X, SMALL_Y) and a[0].shape == a[1].shape


@with_seed()
Expand All @@ -123,8 +122,8 @@ def test_ndarray_random_generalized_negative_binomial():
mu_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.generalized_negative_binomial(mu=mu_array, alpha=alpha_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
Expand All @@ -133,8 +132,8 @@ def test_ndarray_random_negative_binomial():
p_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.negative_binomial(k=k_array, p=p_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
Expand All @@ -144,26 +143,40 @@ def test_ndarray_random_normal():
a = nd.random.normal(loc=loc_array, scale=scale_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0


@with_seed()
def test_ndarray_random_poisson():
lambda_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.poisson(lam=lambda_array, shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
def test_ndarray_random_randn():
a = nd.random.randn(LARGE_X, SMALL_Y)
assert a.shape == (LARGE_X, SMALL_Y)
assert a[-1][0] >= 0
# TODO: Once PR for randn ndarray dtype for loc,scale param merged
# TODO: Once PR #15772 for randn ndarray dtype for loc,scale param merged
# Add check for (x,y,m,n) where x,y shape of loc,scale and m,n input shape


@with_seed()
def test_ndarray_random_shuffle():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
a[-1] == 3 # assign 3 to entire last row
a = nd.random.shuffle(a)
# slice first column from shuffled array
# pass LARGE_X values to numpy instead of LARGE_X*SMALL_Y
# could have assigned to last column (so as to pass SMALL_Y)
# but shuffle operation is performed along first axis
unique_a = np.unique(a[:, 0].asnumpy())
assert len(unique_a) == 2 # only 2 unique values
assert unique_a[0] == 1 # first unique value is 1
assert unique_a[1] == 3 # second unique value is 3
assert a.shape[0] == (LARGE_X, SMALL_Y)


def test_ndarray_empty():
a = nd.empty((LARGE_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_Y)
Expand Down Expand Up @@ -277,7 +290,6 @@ def test_Dense(ctx=mx.cpu(0)):
linear = gluon.nn.Dense(100)
linear.initialize(ctx=ctx)
res = linear(data)
res.wait_to_read()
assert res.shape == (50000000, 100)


Expand Down Expand Up @@ -386,22 +398,22 @@ def test_unravel_index():
def test_transpose():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = b.T
assert t.shape == (SMALL_Y, LARGE_X)
assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]
assert t.shape == (SMALL_Y, LARGE_X)


def test_swapaxes():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = nd.swapaxes(b, dim1=0, dim2=1)
assert t.shape == (SMALL_Y, LARGE_X)
assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]
assert t.shape == (SMALL_Y, LARGE_X)


def test_flip():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = nd.flip(b, axis=0)
assert t.shape == (LARGE_X, SMALL_Y)
assert np.sum(t[-1, :].asnumpy() == 0) == b.shape[1]
assert t.shape == (LARGE_X, SMALL_Y)


def test_softmax():
Expand Down Expand Up @@ -535,7 +547,9 @@ def test_sequence_reverse():
assert b.shape == a.shape

# test with sequence length
b = nd.SequenceReverse(a, sequence_length=[2, 3])
# 2 rows of batch 1 and 3 rows of batch 2 reversed
b = nd.SequenceReverse(a, sequence_length=nd.array([2, 3]),
use_sequence_length=True)
assert b[1][0][0] == a[0][0][0] # check if reversed
assert b[-1][0][0] == a[-1][0][0] # check if intact
assert b.shape == a.shape
Expand All @@ -546,7 +560,7 @@ def test_sequence_last():

# test if returns last sequence
b = nd.SequenceLast(a)
assert_almost_equal(b, a[-1]) # only checks for (2,SMALL_Y) tensor
assert_almost_equal(b.asnumpy(), a[-1].asnumpy()) # only checks for (2,SMALL_Y) tensor
assert b.shape == (2, SMALL_Y)

# test with sequence length
Expand Down Expand Up @@ -715,17 +729,18 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
forward_check_eps)


# TODO: correctness of dropout
# currently only test for dropout to work
# since testing for correctness involves flakiness issue #14288
def test_dropout():
shape = (10, 10)
shape = (LARGE_X, SMALL_Y)
x = mx.sym.var('data')
y = mx.sym.Dropout(x, p=1, cudnn_off=True)
exe = y.simple_bind(ctx=default_context(), data=shape)
exe.arg_arrays[0][:] = 1
out = exe.forward(is_train=True)
out[0].wait_to_read()
assert out.shape == out.shape


def test_activation():
Expand All @@ -736,7 +751,7 @@ def test_activation():
# Hyperbolic tangent (tanh)
# y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
a = mx.nd.Activation(a, act_type="tanh")
tanh_x = (np.exp(-2)-np.exp(2))/(np.exp(-2)+np.exp(2))
tanh_x = (np.exp(test_x)-np.exp(-test_x))/(np.exp(test_x)+np.exp(-test_x))
assert a[-1][-1] == tanh_x

# Recitified Linear Unit (relu)
Expand All @@ -763,8 +778,6 @@ def test_activation():
def test_batchnorm():
shape = (LARGE_X, SMALL_Y)
axis = 1 # default
expand_shape = [1] * len(shape)
expand_shape[axis] = shape[axis]

nch = shape[axis]
data = mx.nd.ones(shape=shape)
Expand All @@ -775,7 +788,7 @@ def test_batchnorm():

output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var)
output.wait_to_read()
assert output.shape == shape


def test_add():
Expand Down
Loading

0 comments on commit 3387c81

Please sign in to comment.