
Fix large array tests #16328

Merged
merged 13 commits on Oct 14, 2019
2 changes: 1 addition & 1 deletion src/operator/contrib/index_copy-inl.h
@@ -71,7 +71,7 @@ inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->at(0)[i], in_attrs->at(2)[i]);
}
}
// The the length of the fitrst dim of copied tensor
// The length of the first dim of the copied tensor
// must equal the size of the index vector
CHECK_EQ(in_attrs->at(1)[0], in_attrs->at(2)[0]);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
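For context, a minimal usage sketch of the operator whose shape check is shown above (toy sizes chosen here for illustration): the length of `index` must equal the first dimension of the tensor of new values, per the `CHECK_EQ` above.

```python
import mxnet as mx

# index has length 2, so t must supply exactly 2 rows (first-dim match).
x = mx.nd.zeros((5, 3))
t = mx.nd.arange(1, 7).reshape((2, 3))
index = mx.nd.array([0, 4], dtype="int64")
out = mx.nd.contrib.index_copy(x, index, t)  # rows 0 and 4 of x replaced by t's rows
```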
4 changes: 2 additions & 2 deletions src/operator/contrib/index_copy.cc
@@ -28,12 +28,12 @@ namespace op {

struct index_copy_fwd_cpu {
template<typename DType, typename IType>
static void Map(int i,
static void Map(index_t i,
const DType* new_tensor,
const IType* idx,
DType* out_tensor,
int dim_size) {
DType* out_ptr = out_tensor + static_cast<int>(idx[i]) * dim_size;
DType* out_ptr = out_tensor + static_cast<index_t>(idx[i]) * dim_size;
const DType* new_ptr = new_tensor + i * dim_size;
std::memcpy(out_ptr, new_ptr, sizeof(DType) * dim_size);
}
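The `int` → `index_t` change matters because the element offset `idx[i] * dim_size` can exceed `INT_MAX` on large tensors, wrapping the destination pointer. A small sketch of the failure mode, using numpy purely for illustration:

```python
import numpy as np

# In 32-bit arithmetic the row offset wraps once the product passes 2**31 - 1.
idx, dim_size = 25_000_000, 100
print(np.int32(idx) * np.int32(dim_size))  # -1794967296: wrapped, points backwards
print(np.int64(idx) * np.int64(dim_size))  # 2500000000: the intended offset
```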
147 changes: 67 additions & 80 deletions tests/nightly/test_large_array.py
@@ -21,7 +21,8 @@

from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed, teardown
from tests.python.unittest.common import with_seed, with_post_test_cleanup
from nose.tools import with_setup

# dimension constants
MEDIUM_X = 10000
@@ -84,20 +85,20 @@ def test_ndarray_random_randint():

@with_seed()
def test_ndarray_random_exponential():
scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
a = nd.random.exponential(scale=scale_array, shape=(SMALL_X, SMALL_Y))
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)


@with_seed()
def test_ndarray_random_gamma():
alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
beta_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
beta_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
a = nd.random.gamma(alpha=alpha_array, beta=beta_array,
shape=(SMALL_X, SMALL_Y))
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)


@with_seed()
@@ -108,50 +109,50 @@ def test_ndarray_random_multinomial():
assert a[-1] >= 0
assert a.shape == (LARGE_X,)
# test for NDArray multi-dimension shape
a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y))
a = nd.random.multinomial(probs, shape=(2, SMALL_Y))
assert a[-1][0][0] >= 0
assert a.shape == (LARGE_X, SMALL_X, SMALL_Y)
assert a.shape == (LARGE_X, 2, SMALL_Y)
# test log_likelihood output shape
a = nd.random.multinomial(probs, shape=(SMALL_X, SMALL_Y), get_prob=True)
assert a[-1][0][0] >= 0
assert a[0].shape == (LARGE_X, SMALL_X, SMALL_Y) and a[0].shape == a[1].shape
a = nd.random.multinomial(probs, shape=(2, SMALL_Y), get_prob=True)
assert a[0][0][0][0] >= 0
assert a[0].shape == (LARGE_X, 2, SMALL_Y) and a[0].shape == a[1].shape
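A toy-scale sketch of the semantics the corrected asserts rely on (same API, small dims, values chosen here for illustration): sampling per-row distributions prepends the batch dimension, and `get_prob=True` returns the samples together with their log-likelihoods, shape for shape.

```python
from mxnet import nd

probs = nd.array([[0.1, 0.9], [0.5, 0.5]])  # 2 rows = 2 distributions
s, logpdf = nd.random.multinomial(probs, shape=(3,), get_prob=True)
assert s.shape == logpdf.shape == (2, 3)    # batch dim first, then sample shape
```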


@with_seed()
def test_ndarray_random_generalized_negative_binomial():
alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
mu_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
mu_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
a = nd.random.generalized_negative_binomial(mu=mu_array, alpha=alpha_array,
shape=(SMALL_X, SMALL_Y))
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)


@with_seed()
def test_ndarray_random_negative_binomial():
k_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
p_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
k_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
p_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
a = nd.random.negative_binomial(k=k_array, p=p_array,
shape=(SMALL_X, SMALL_Y))
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)


@with_seed()
def test_ndarray_random_normal():
scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
loc_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
loc_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
a = nd.random.normal(loc=loc_array, scale=scale_array,
shape=(SMALL_X, SMALL_Y))
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)


@with_seed()
def test_ndarray_random_poisson():
lambda_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_Y))
lambda_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
a = nd.random.poisson(lam=lambda_array, shape=(SMALL_X, SMALL_Y))
assert a[-1][0][0][0] >= 0
assert a.shape == (MEDIUM_X, SMALL_Y, SMALL_X, SMALL_Y)
assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)
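All of the shape fixes above follow one rule: for `nd.random.*` with tensor parameters of shape `P` and `shape=S`, each parameter element gets its own `S`-shaped draw, so the output shape is `P + S`. A toy-scale sketch of the same rule:

```python
from mxnet import nd

lam = nd.random.uniform(shape=(4, 5))          # parameter shape P = (4, 5)
a = nd.random.poisson(lam=lam, shape=(2, 3))   # sample shape    S = (2, 3)
assert a.shape == (4, 5, 2, 3)                 # output shape      P + S
```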


@with_seed()
@@ -165,7 +166,7 @@ def test_ndarray_random_randn():
@with_seed()
def test_ndarray_random_shuffle():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
a[-1] == 3 # assign 3 to entire last row
a[-1] = 3 # assign 3 to entire last row
a = nd.random.shuffle(a)
# slice first column from shuffled array
# pass LARGE_X values to numpy instead of LARGE_X*SMALL_Y
@@ -175,7 +176,7 @@ def test_ndarray_random_shuffle():
assert len(unique_a) == 2 # only 2 unique values
assert unique_a[0] == 1 # first unique value is 1
assert unique_a[1] == 3 # second unique value is 3
assert a.shape[0] == (LARGE_X, SMALL_Y)
assert a.shape == (LARGE_X, SMALL_Y)


def test_ndarray_empty():
@@ -269,6 +270,7 @@ def test_slice_assign():
def test_expand_dims():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
res = nd.expand_dims(a, axis=1)
assert res[0][0][0] == 1
assert res.shape == (a.shape[0], 1, a.shape[1])


@@ -561,7 +563,7 @@ def test_sequence_last():

# test if returns last sequence
b = nd.SequenceLast(a)
assert_almost_equal(b.asnumpy(), a[-1].asnumpy()) # only checks for (2,SMALL_Y) tensor
assert_almost_equal(b.asnumpy(), a[-1].asnumpy()) # only checks for (2, SMALL_Y) tensor
assert b.shape == (2, SMALL_Y)

# test with sequence length
@@ -600,7 +602,7 @@ def test_softmax_cross_entropy():
def test_index_copy():
x = mx.nd.zeros((LARGE_X, SMALL_Y))
t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y))
index = mx.nd.array([LARGE_X - 1])
index = mx.nd.array([LARGE_X - 1], dtype="int64")

x = mx.nd.contrib.index_copy(x, index, t)
assert x[-1][-1] == t[0][-1]
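The explicit `dtype="int64"` is the point of this fix: the default float32 storage has a 24-bit mantissa and cannot represent `LARGE_X - 1` exactly (assuming `LARGE_X = 100000000` as defined elsewhere in this file), so the index would silently round to a different row.

```python
import numpy as np

# float32 rounds 99999999 to the nearest representable value, 100000000:
assert int(np.float32(99999999)) == 100000000   # index corrupted
assert int(np.int64(99999999)) == 99999999      # exact
```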
@@ -637,23 +639,23 @@ def test_leaky_relu():

def test_leaky():
res = mx.nd.LeakyReLU(a, act_type="leaky", slope=0.3)
assert res[-1][-1].asnumpy() == 0.3*a[-1][-1].asnumpy()
assert_almost_equal(res[-1][-1].asnumpy(), 0.3*a[-1][-1].asnumpy(), atol=1e-3, rtol=1e-3)

def test_elu():
res = mx.nd.LeakyReLU(a, act_type="elu", slope=0.3)
assert res[-1][-1].asnumpy() == 0.3*(np.exp(a[-1][-1].asnumpy())-1)
assert_almost_equal(res[-1][-1].asnumpy(), 0.3*(np.exp(a[-1][-1].asnumpy())-1), atol=1e-3, rtol=1e-3)

def test_selu():
lam = 1.0507009873554804934193349852946
alpha = 1.6732632423543772848170429916717
res = mx.nd.LeakyReLU(a, act_type="selu")
assert res[-1][-1].asnumpy() == (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1))
assert_almost_equal(res[-1][-1].asnumpy(), (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1)), atol=1e-3, rtol=1e-3)

def test_rrelu():
lower = 0.125
upper = 0.333999991
res = mx.nd.LeakyReLU(a, act_type="rrelu")
assert res[-1][-1].asnumpy() == (lower + upper) / 2 * a[-1][-1].asnumpy()
assert_almost_equal(res[0][-1][-1].asnumpy(), (lower + upper) / 2 * a[-1][-1].asnumpy(), atol=1e-3, rtol=1e-3)

test_leaky()
test_elu()
@@ -662,31 +664,31 @@ def test_rrelu():


def test_pooling():
a = mx.nd.ones((MEDIUM_X, MEDIUM_X, SMALL_Y, SMALL_Y))
a = mx.nd.ones((MEDIUM_X, 200, SMALL_Y, SMALL_Y))

def test_avg_pooling():
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='avg')
assert res[-1][-1][-1][-1] == 1.0000001
assert res.shape == SMALL_Y - 5 + 1
assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 1.0000001, atol=1e-3, rtol=1e-3)
assert res.shape[-1] == SMALL_Y - 5 + 1

def test_max_pooling():
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max')
assert res[-1][-1][-1][-1] == 1.
assert res.shape == SMALL_Y - 5 + 1
assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 1., atol=1e-3, rtol=1e-3)
assert res.shape[-1] == SMALL_Y - 5 + 1

def test_sum_pooling():
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='sum')
assert res[-1][-1][-1][-1] == 25
assert res.shape == SMALL_Y - 5 + 1
assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 25, atol=1e-3, rtol=1e-3)
assert res.shape[-1] == SMALL_Y - 5 + 1

def test_lp_pooling():
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=2)
assert res[-1][-1][-1][-1] == 5.
assert res.shape == SMALL_Y - 5 + 1
assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 5., atol=1e-3, rtol=1e-3)
assert res.shape[-1] == SMALL_Y - 5 + 1

res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=1)
assert res[-1][-1][-1][-1] == 25.
assert res.shape == SMALL_Y - 5 + 1
assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 25., atol=1e-3, rtol=1e-3)
assert res.shape[-1] == SMALL_Y - 5 + 1

test_avg_pooling()
test_max_pooling()
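The corrected shape asserts use the standard output size for stride-1, unpadded pooling, `out = in - kernel + 1`, checked against the last axis only since the full output is 4-D. A toy-scale sketch of the same relationship:

```python
import mxnet as mx

a = mx.nd.ones((1, 1, 8, 8))                      # NCHW layout
res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max')
assert res.shape == (1, 1, 8 - 5 + 1, 8 - 5 + 1)  # (1, 1, 4, 4)
```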
@@ -741,36 +743,37 @@ def test_dropout():
exe = y.simple_bind(ctx=default_context(), data=shape)
exe.arg_arrays[0][:] = 1
out = exe.forward(is_train=True)
assert out.shape == out.shape
nd.waitall()
assert out[0].shape == shape


def test_activation():
a = mx.nd.ones((LARGE_X, SMALL_Y))
x = mx.nd.ones((LARGE_X, SMALL_Y))
test_x = -2
a[-1, -1] = test_x
x[-1, -1] = test_x

# Hyperbolic tangent (tanh)
# y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
a = mx.nd.Activation(a, act_type="tanh")
tanh_x = (np.exp(test_x)-np.exp(-test_x))/(np.exp(test_x)+np.exp(-test_x))
assert a[-1][-1] == tanh_x
y = mx.nd.Activation(x, act_type="tanh")
tanh_x = ((np.exp(test_x)-np.exp(-test_x))/(np.exp(test_x)+np.exp(-test_x)))
assert y[-1][-1] == np.float32(tanh_x)

# Rectified Linear Unit (relu)
# y = max(x,0)
a = mx.nd.Activation(a, act_type="relu")
assert a[-1][-1] == 0
y = mx.nd.Activation(x, act_type="relu")
assert y[-1][-1] == 0

# Sigmoid
# y = 1/(1+exp(-x))
a = mx.nd.Activation(a, act_type="sigmoid")
sigmoid_x = 1/(1+math.exp(-test_x))
assert a[-1][-1] == sigmoid_x
y = mx.nd.Activation(x, act_type="sigmoid")
sigmoid_x = (1/(1+math.exp(-test_x)))
assert_almost_equal(y[-1][-1].asnumpy(), np.float32(sigmoid_x), atol=1e-3, rtol=1e-3)

# Soft Sign
# y = x/(1+abs(x))
a = mx.nd.Activation(a, act_type="softsign")
softsign_x = test_x/(1+abs(test_x))
assert a[-1][-1] == softsign_x
y = mx.nd.Activation(x, act_type="softsign")
softsign_x = (test_x/(1+abs(test_x)))
assert y[-1][-1] == np.float32(softsign_x)
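The reference formulas in the comments can be sanity-checked numerically at `test_x = -2` with plain numpy, matching the expressions asserted above:

```python
import numpy as np

x = -2.0
print(np.tanh(x))            # tanh:     -0.9640275800758169
print(1 / (1 + np.exp(-x)))  # sigmoid:   0.11920292202211755
print(x / (1 + abs(x)))      # softsign: -0.6666666666666666
```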


# TODO: correctness of batchnorm
@@ -924,8 +927,7 @@ def test_copy_to():
b = nd.array(np.zeros((SMALL_Y, LARGE_X)))
c = a.copyto(b)
assert c is b
print(b)
assert b[0][-1] == LARGE_X-1
assert b[-1][-1] == SMALL_Y-1


def test_zeros_like():
@@ -957,24 +959,17 @@ def test_flatten():
assert b.shape == (LARGE_X//2, SMALL_Y*2)


def test_expand_dims():
a = nd.array(np.ones((SMALL_Y, LARGE_X)))
b = nd.expand_dims(a, axis=1)
nd.waitall()
assert b.shape == (SMALL_Y, 1, LARGE_X)


def test_concat():
a = nd.array(np.ones((SMALL_Y, LARGE_X)))
b = nd.array(np.zeros((SMALL_Y, LARGE_X)))
c = nd.concat(a,b, dim=0)
c = nd.concat(a, b, dim=0)
assert c.shape == (b.shape[0]*2, LARGE_X)


def test_stack():
a = nd.array(np.ones((SMALL_Y, LARGE_X)))
b = nd.array(np.zeros((SMALL_Y, LARGE_X)))
c = nd.stack(a,b, axis=1)
c = nd.stack(a, b, axis=1)
assert c.shape == (b.shape[0], 2, LARGE_X)


@@ -1019,7 +1014,7 @@ def test_max():
def test_norm():
a = np.array(np.full((1, LARGE_X), 3))
b = np.array(np.full((1, LARGE_X), 4))
c = nd.array(np.concatenate((a,b), axis=0))
c = nd.array(np.concatenate((a, b), axis=0))
d = nd.norm(c, ord=2, axis=0)
e = nd.norm(c, ord=1, axis=0)
assert d.shape[0] == LARGE_X
@@ -1031,7 +1026,7 @@ def test_norm():
def test_argmax():
a = np.ones((SMALL_Y, LARGE_X))
b = np.zeros((SMALL_Y, LARGE_X))
c = nd.array(np.concatenate((a,b), axis=0))
c = nd.array(np.concatenate((a, b), axis=0))
d = nd.argmax(c, axis=0)
assert d.shape[0] == LARGE_X
assert d[-1] == d[0] == 0
@@ -1040,12 +1035,13 @@ def test_argmax():
def test_relu():
def frelu(x):
return np.maximum(x, 0.0)

def frelu_grad(x):
return 1.0 * (x > 0.0)
shape = (SMALL_Y, LARGE_X)
x = mx.symbol.Variable("x")
y = mx.sym.relu(x)
xa = np.random.uniform(low=-1.0,high=1.0,size=shape)
xa = np.random.uniform(low=-1.0, high=1.0, size=shape)
eps = 1e-4
xa[abs(xa) < eps] = 1.0
ya = frelu(xa)
@@ -1059,7 +1055,7 @@ def fsigmoid(a):
shape = (SMALL_Y, LARGE_X)
x = mx.symbol.Variable("x")
y = mx.sym.sigmoid(x)
xa = np.random.uniform(low=-1.0,high=1.0,size=shape)
xa = np.random.uniform(low=-1.0, high=1.0, size=shape)
ya = fsigmoid(xa)
check_symbolic_forward(y, [xa], [ya])

@@ -1116,15 +1112,6 @@ def test_idiv():
assert c[0][-1] == 2


def test_imod():
a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2)))
c = a
c %= b
assert c.shape == a.shape
assert c[0][-1] == 1


def test_eq():
a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
@@ -1198,7 +1185,7 @@ def test_slice_axis():


def test_one_hot():
#default dtype of ndarray is float32 which cannot index elements over 2^32
# default dtype of ndarray is float32 which cannot index elements over 2^32
a = nd.array([1, (VLARGE_X - 1)], dtype=np.int64)
b = nd.one_hot(a, VLARGE_X)
assert b[0][1] == 1
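The int64 requirement follows from float32's 24-bit mantissa: above 2**24, not every integer index is representable, so distinct indices can collide. A minimal demonstration:

```python
import numpy as np

assert np.float32(2**24) == np.float32(2**24 + 1)  # indices collide in float32
assert np.int64(2**24) != np.int64(2**24 + 1)      # int64 keeps them distinct
```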