This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add Large tensor vector test cases #15941

Merged: 31 commits, Sep 4, 2019
Commits
eb0d524
add random ops
ChaiBapchya Aug 16, 2019
ccfd4f8
add shuffle to test large array
ChaiBapchya Aug 17, 2019
fd8cc04
shape evaluation after value check
ChaiBapchya Aug 17, 2019
2408f73
add log, exponent, power ops
ChaiBapchya Aug 17, 2019
cb6fd0f
fix sequence reverse issue in test_large_array and add sequence ops t…
ChaiBapchya Aug 19, 2019
01fefe6
add binary arithmetic
ChaiBapchya Aug 19, 2019
9a4eb2e
fix lint, minor mistakes in large_array; add nn op to tensor
ChaiBapchya Aug 19, 2019
351411e
Trigger notification coz of test_operator.test_laop_6 error
ChaiBapchya Aug 19, 2019
a9ce8fe
Trigger notification coz of test_operator.test_laop_6 error
ChaiBapchya Aug 19, 2019
627ba82
Trigger notification bcoz R failures
ChaiBapchya Aug 21, 2019
3f12e1e
address comments
ChaiBapchya Aug 21, 2019
4b5a835
normal distribution assert statement fix; randint dtype check
ChaiBapchya Aug 21, 2019
1274f14
correct layernorm and shuffle
ChaiBapchya Aug 22, 2019
03563bd
layer norm numpy flaky hence removed, dropout shape fix
ChaiBapchya Aug 22, 2019
f984a0d
comment not working ops
ChaiBapchya Aug 22, 2019
acb1eab
fix multi
ChaiBapchya Aug 23, 2019
0de3a00
Trigger notification
ChaiBapchya Aug 23, 2019
a9e01a1
Merge branch 'master' into lts_vector
ChaiBapchya Aug 23, 2019
a47beb6
fix seq reverse, uncomment seq mask as it works
ChaiBapchya Aug 23, 2019
9bf8f7f
index fix and uncomment test
ChaiBapchya Aug 23, 2019
ceb04ef
index fix
ChaiBapchya Aug 24, 2019
268d143
seq_reverse index fix
ChaiBapchya Aug 26, 2019
aca1edd
uncomment seq reverse test and handle static typecasts
ChaiBapchya Aug 26, 2019
dd17bec
removing commented ops
ChaiBapchya Aug 27, 2019
57a10ac
Merge branch 'master' into lts_vector
ChaiBapchya Aug 27, 2019
1e9349a
resolve merge conflict
ChaiBapchya Aug 27, 2019
513581f
Merge branch 'master' into lts_vector
ChaiBapchya Aug 29, 2019
b38eed5
teardown, lint, remove redundant functions
ChaiBapchya Aug 29, 2019
31521f3
fix shape assertions and randint low,high
ChaiBapchya Aug 30, 2019
cb1c5fb
remove waits, add teardown to large_array, change randint assert in l…
ChaiBapchya Aug 30, 2019
fe6a15f
Merge branch 'master' into lts_vector
ChaiBapchya Aug 31, 2019
24 changes: 12 additions & 12 deletions src/operator/sequence_last-inl.h
@@ -66,24 +66,24 @@ struct SequenceLastParam : public dmlc::Parameter<SequenceLastParam> {
template <int req>
struct SequenceLastKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
const IType *idx, int offset1, int offset2,
MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in,
const IType *idx, index_t offset1, index_t offset2,
mshadow::Shape<2> oshape) {
const auto opos = mxnet_op::unravel(i, oshape);
const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
const int ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
const index_t seqpos = static_cast<index_t>(idx[opos[0]]) - 1;
const index_t ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
KERNEL_ASSIGN(out[i], req, in[ipos]);
}
};

struct SequenceLastGradKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad,
const IType *idx, int offset1, int offset2,
MSHADOW_XINLINE static void Map(index_t i, DType *in_grad, const DType *out_grad,
const IType *idx, index_t offset1, index_t offset2,
mshadow::Shape<2> oshape) {
const auto opos = mxnet_op::unravel(i, oshape);
const int seqpos = static_cast<int>(idx[opos[0]]) - 1;
const int ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
const index_t seqpos = static_cast<index_t>(idx[opos[0]]) - 1;
const index_t ipos = seqpos * offset1 + opos[0] * offset2 + opos[1];
in_grad[ipos] += out_grad[i];
}
};
@@ -103,8 +103,8 @@ class SequenceLastOp : public Operator {
int axis = param_.axis;
int out_size = out.size(0) * out.size(1);
int max_seq_len = data.size(axis);
int offset1 = axis ? out.size(1) : out_size;
int offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1);
index_t offset1 = axis ? out.size(1) : out_size;
index_t offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1);

MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
mxnet_op::Kernel<SequenceLastKernel<req_type>, xpu>::Launch(
@@ -126,8 +126,8 @@ class SequenceLastOp : public Operator {
int out_size = batch * rest;

int max_seq_len = in_grad.size(axis);
int offset1 = axis ? rest : out_size;
int offset2 = axis ? (max_seq_len * rest) : rest;
index_t offset1 = axis ? rest : out_size;
index_t offset2 = axis ? (max_seq_len * rest) : rest;

mxnet_op::Kernel<SequenceLastGradKernel, xpu>::Launch(
s, out_size, in_grad.dptr_, out_grad.dptr_, indices.dptr_, offset1,
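The change in this file widens the kernel's element index and offsets from int to index_t so that positions past 2^31 - 1 elements no longer overflow. A minimal Python sketch of the failure mode (illustration only, not part of the PR; numpy integer types stand in for the C++ ones):

import numpy as np

num_elements = 2**31 + 10  # more elements than a 32-bit signed int can address
idx32 = np.array([num_elements]).astype(np.int32)[0]  # C-style cast: wraps around
idx64 = np.array([num_elements]).astype(np.int64)[0]  # what a 64-bit index_t provides
print(idx32)  # -2147483638, a negative (invalid) offset
print(idx64)  # 2147483658, the correct offset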
14 changes: 7 additions & 7 deletions src/operator/sequence_reverse-inl.h
@@ -67,30 +67,30 @@ struct SequenceReverseParam : public dmlc::Parameter<SequenceReverseParam> {
template <OpReqType req>
struct ReverseKernel {
template <typename DType, typename IType>
MSHADOW_XINLINE static void Map(const int i, DType *const out_data,
MSHADOW_XINLINE static void Map(const index_t i, DType *const out_data,
const DType *const in_data,
const index_t max_seq_len,
const index_t batch_size,
const index_t other_dim, const index_t numel,
const IType *const indices) {
const index_t batch = i / (max_seq_len * other_dim);
const int id = (i / other_dim) % max_seq_len;
const index_t id = (i / other_dim) % max_seq_len;
const index_t j = i % other_dim;
const index_t num_seq =
indices ? static_cast<index_t>(indices[batch]) : max_seq_len;
const index_t padded_periods = max_seq_len - num_seq;
// padded part
if (padded_periods > 0 && id < static_cast<int>(padded_periods)) {
const int padded_in_offset =
if (padded_periods > 0 && id < padded_periods) {
const index_t padded_in_offset =
(id + num_seq) * batch_size * other_dim + batch * other_dim;

KERNEL_ASSIGN(out_data[padded_in_offset + j], req,
in_data[padded_in_offset + j]);
}
// unpadded part
if (id < static_cast<int>(num_seq)) {
const int in_offset = id * batch_size * other_dim + batch * other_dim;
const int out_offset =
if (id < num_seq) {
const index_t in_offset = id * batch_size * other_dim + batch * other_dim;
const index_t out_offset =
numel - (id + 1 + padded_periods) * batch_size * other_dim +
batch * other_dim;

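ReverseKernel gets the same int to index_t widening; the kernel reverses the first num_seq time steps of each batch element and copies the padded tail through unchanged. A small numpy sketch of that logic (illustrative only; it assumes the (max_seq_len, batch_size, other_dim) layout the kernel indexes into):

import numpy as np

def sequence_reverse_ref(data, lengths):
    # Reverse the first lengths[b] steps of batch element b ("unpadded part")
    # and leave steps >= lengths[b] untouched ("padded part").
    out = data.copy()
    for b, num_seq in enumerate(lengths):
        out[:num_seq, b] = data[:num_seq, b][::-1]
    return out

x = np.arange(24).reshape(4, 2, 3)   # (max_seq_len=4, batch_size=2, other_dim=3)
y = sequence_reverse_ref(x, [2, 4])
assert (y[0, 0] == x[1, 0]).all()    # batch 0: first two steps swapped
assert (y[2:, 0] == x[2:, 0]).all()  # batch 0: padded tail intact
assert (y[0, 1] == x[3, 1]).all()    # batch 1: fully reversed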
67 changes: 40 additions & 27 deletions tests/nightly/test_large_array.py
@@ -21,7 +21,7 @@

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed
from tests.python.unittest.common import with_seed, teardown

# dimension constants
MEDIUM_X = 10000
@@ -56,10 +56,8 @@ def test_ndarray_ones():
def test_ndarray_convert():
a = nd.zeros(shape=(LARGE_X, SMALL_Y))
b = a.astype(np.int32)
b.wait_to_read()
assert b.dtype == np.int32
b = a.tostype('row_sparse')
b.wait_to_read()
assert isinstance(b, mx.nd.sparse.RowSparseNDArray)


@@ -79,15 +79,16 @@ def test_ndarray_random_randint():
a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64)
low = mx.nd.array([low_large_value], dtype='int64')
high = mx.nd.array([high_large_value], dtype='int64')
assert a.__gt__(low) and a.__lt__(high)
assert a >= low and a < high
assert a[-1][0].dtype == np.int64


@with_seed()
def test_ndarray_random_exponential():
scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.exponential(scale=scale_array, shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
@@ -96,25 +96,25 @@ def test_ndarray_random_gamma():
beta_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.gamma(alpha=alpha_array, beta=beta_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
def test_ndarray_random_multinomial():
# test 1 shape dimension
probs = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
a = nd.random.multinomial(probs)
assert a.shape == (LARGE_X,)
assert a[-1] >= 0
assert a.shape == (LARGE_X,)
# test for NDArray multi-dimension shape
a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_X, SMALL_Y)
assert a[-1][0][0] >= 0
assert a.shape == (LARGE_X, SMALL_X, SMALL_Y)
# test log_likelihood output shape
a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y), get_prob=True)
assert a[0].shape == (LARGE_X, SMALL_X, SMALL_Y) and a[0].shape == a[1].shape
assert a[-1][0][0] >= 0
assert a[0].shape == (LARGE_X, SMALL_X, SMALL_Y) and a[0].shape == a[1].shape


@with_seed()
@@ -123,8 +122,8 @@ def test_ndarray_random_generalized_negative_binomial():
mu_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.generalized_negative_binomial(mu=mu_array, alpha=alpha_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
@@ -133,8 +132,8 @@ def test_ndarray_random_negative_binomial():
p_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.negative_binomial(k=k_array, p=p_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
@@ -144,26 +143,40 @@ def test_ndarray_random_normal():
a = nd.random.normal(loc=loc_array, scale=scale_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0


@with_seed()
def test_ndarray_random_poisson():
lambda_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
a = nd.random.poisson(lam=lambda_array, shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)


@with_seed()
def test_ndarray_random_randn():
a = nd.random.randn(LARGE_X, SMALL_Y)
assert a.shape == (LARGE_X, SMALL_Y)
assert a[-1][0] >= 0
# TODO: Once PR for randn ndarray dtype for loc,scale param merged
# TODO: Once PR #15772 for randn ndarray dtype for loc,scale param merged
# Add check for (x,y,m,n) where x,y shape of loc,scale and m,n input shape


@with_seed()
def test_ndarray_random_shuffle():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
a[-1] = 3  # assign 3 to entire last row
a = nd.random.shuffle(a)
# slice first column from shuffled array
# pass LARGE_X values to numpy instead of LARGE_X*SMALL_Y
# could have assigned to last column (so as to pass SMALL_Y)
# but shuffle operation is performed along first axis
unique_a = np.unique(a[:, 0].asnumpy())
assert len(unique_a) == 2 # only 2 unique values
assert unique_a[0] == 1 # first unique value is 1
assert unique_a[1] == 3 # second unique value is 3
assert a.shape == (LARGE_X, SMALL_Y)
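The checking strategy in this test (mark one row, shuffle, inspect a single column) can be sanity-checked at small scale; this is an illustrative sketch, not part of the PR, with np.random.shuffle standing in for nd.random.shuffle (both permute along the first axis):

import numpy as np

small = np.ones((5, 4))
small[-1] = 3                 # mark the last row
np.random.shuffle(small)      # permutes whole rows, so each row moves as a unit
col = np.unique(small[:, 0])  # one column is enough to see every row's marker
assert len(col) == 2 and col[0] == 1 and col[1] == 3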


def test_ndarray_empty():
a = nd.empty((LARGE_X, SMALL_Y))
assert a.shape == (LARGE_X, SMALL_Y)
@@ -277,7 +290,6 @@ def test_Dense(ctx=mx.cpu(0)):
linear = gluon.nn.Dense(100)
linear.initialize(ctx=ctx)
res = linear(data)
res.wait_to_read()
assert res.shape == (50000000, 100)


@@ -386,22 +398,22 @@ def test_unravel_index():
def test_transpose():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = b.T
assert t.shape == (SMALL_Y, LARGE_X)
assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]
assert t.shape == (SMALL_Y, LARGE_X)


def test_swapaxes():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = nd.swapaxes(b, dim1=0, dim2=1)
assert t.shape == (SMALL_Y, LARGE_X)
assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]
assert t.shape == (SMALL_Y, LARGE_X)


def test_flip():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = nd.flip(b, axis=0)
assert t.shape == (LARGE_X, SMALL_Y)
assert np.sum(t[-1, :].asnumpy() == 0) == b.shape[1]
assert t.shape == (LARGE_X, SMALL_Y)


def test_softmax():
@@ -535,7 +547,9 @@ def test_sequence_reverse():
assert b.shape == a.shape

# test with sequence length
b = nd.SequenceReverse(a, sequence_length=[2, 3])
# 2 rows of batch 1 and 3 rows of batch 2 reversed
b = nd.SequenceReverse(a, sequence_length=nd.array([2, 3]),
use_sequence_length=True)
assert b[1][0][0] == a[0][0][0] # check if reversed
assert b[-1][0][0] == a[-1][0][0] # check if intact
assert b.shape == a.shape
@@ -546,7 +560,7 @@ def test_sequence_last():

# test if returns last sequence
b = nd.SequenceLast(a)
assert_almost_equal(b, a[-1]) # only checks for (2,SMALL_Y) tensor
assert_almost_equal(b.asnumpy(), a[-1].asnumpy()) # only checks for (2,SMALL_Y) tensor
assert b.shape == (2, SMALL_Y)

# test with sequence length
@@ -715,17 +729,18 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
forward_check_eps)


# TODO: correctness of dropout
# currently only test for dropout to work
# since testing for correctness involves flakiness issue #14288
def test_dropout():
shape = (10, 10)
shape = (LARGE_X, SMALL_Y)
x = mx.sym.var('data')
y = mx.sym.Dropout(x, p=1, cudnn_off=True)
exe = y.simple_bind(ctx=default_context(), data=shape)
exe.arg_arrays[0][:] = 1
out = exe.forward(is_train=True)
out[0].wait_to_read()
assert out[0].shape == shape


def test_activation():
@@ -736,7 +751,7 @@ def test_activation():
# Hyperbolic tangent (tanh)
# y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
a = mx.nd.Activation(a, act_type="tanh")
tanh_x = (np.exp(-2)-np.exp(2))/(np.exp(-2)+np.exp(2))
tanh_x = (np.exp(test_x)-np.exp(-test_x))/(np.exp(test_x)+np.exp(-test_x))
Contributor:
nit: add space around operators. Do this across the entire file.

Contributor Author:
PEP8 and pylint both didn't flag this. I'll do it, but is there some other linting tool apart from PEP8 and pylint?

If it's a Google style guide tool, is this the way to go?
https://stackoverflow.com/questions/29597618/is-there-a-tool-to-lint-python-based-on-the-google-style-guide

assert a[-1][-1] == tanh_x

# Rectified Linear Unit (relu)
@@ -763,8 +778,6 @@ def test_batchnorm():
def test_batchnorm():
shape = (LARGE_X, SMALL_Y)
axis = 1 # default
expand_shape = [1] * len(shape)
expand_shape[axis] = shape[axis]

nch = shape[axis]
data = mx.nd.ones(shape=shape)
@@ -775,7 +788,7 @@

output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var)
output.wait_to_read()
assert output.shape == shape


def test_add():