
Add large tensor nightly tests for MKL-DNN operators (#16184)
* merge unit tests into the original script

* use LARGE_X to define the input shape

* add inline comments

* address comments

* rebase code
wuxun-zhang authored and marcoabreu committed Nov 19, 2019
1 parent 135c42c commit 4f14bf4
Showing 1 changed file with 82 additions and 22 deletions.
104 changes: 82 additions & 22 deletions tests/nightly/test_large_array.py
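For context, the tests below rely on the module-level constants LARGE_X and SMALL_Y and on the create_2d_tensor helper, all defined near the top of tests/nightly/test_large_array.py and therefore outside this diff. A minimal sketch of what they are assumed to look like (the concrete values and the helper body are assumptions, not part of this commit):

import numpy as np
from mxnet import nd

# Assumed values: chosen so that LARGE_X * SMALL_Y exceeds 2**32, which is what
# forces the int64-index (large tensor) code paths under test here.
LARGE_X = 100000000
SMALL_Y = 50

def create_2d_tensor(rows, columns, dtype=np.int64):
    # Row i holds the value i in every column, so row identity survives
    # reshape/transpose and can be checked cheaply.
    a = nd.arange(0, rows, dtype=dtype).reshape((rows, 1))
    return nd.broadcast_to(a, shape=(rows, columns))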
@@ -212,8 +212,15 @@ def test_dot():
def test_FullyConnected():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
b = nd.ones(shape=(SMALL_Y, SMALL_Y))
res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True)
assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1]
c = nd.ones(shape=(b.shape[0],))
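# With an all-ones (LARGE_X, SMALL_Y) input and an all-ones (SMALL_Y, SMALL_Y) weight,
# every output element equals SMALL_Y (= a.shape[1]); the all-ones bias adds 1 on top.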

# w/o bias
res = nd.FullyConnected(a, b, num_hidden=b.shape[0], no_bias=True)
assert np.sum(res[-1].asnumpy() == a.shape[1]) == b.shape[0]

# w/ bias
res = nd.FullyConnected(a, b, c, num_hidden=b.shape[0], no_bias=False)
assert np.sum(res[-1].asnumpy() == a.shape[1] + 1) == b.shape[0]


def test_broadcast():
@@ -272,6 +279,7 @@ def test_slice_assign():
def test_expand_dims():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
res = nd.expand_dims(a, axis=1)
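# expand_dims is executed asynchronously; wait_to_read() blocks until the result
# is materialized before the checks below run.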
res.wait_to_read()
assert a[0][0][0] == 1
assert res.shape == (a.shape[0], 1, a.shape[1])

@@ -401,10 +409,14 @@ def test_unravel_index():


def test_transpose():
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
t = b.T
assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]
assert t.shape == (SMALL_Y, LARGE_X)
test_dtypes = [np.float32, np.int64]
for dtype in test_dtypes:
b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y, dtype=dtype)
t = b.T
assert t.shape == (SMALL_Y, LARGE_X)
ref_out = np.transpose(b.asnumpy())
assert_almost_equal(t.asnumpy(), ref_out, rtol=1e-10)



def test_swapaxes():
@@ -423,9 +435,10 @@ def test_flip():

def test_softmax():
input_data = mx.nd.ones((SMALL_Y, LARGE_X))
true_output = np.full((SMALL_Y, LARGE_X), (1 / SMALL_Y))
output = nd.softmax(input_data, axis=0)
assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)
for axis in [0, 1]:
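# A constant input has a uniform softmax, so along either axis every entry
# equals 1 / input_data.shape[axis].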
true_output = np.full((SMALL_Y, LARGE_X), (1 / input_data.shape[axis]))
output = nd.softmax(input_data, axis=axis)
assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)


def test_argsort():
@@ -619,12 +632,19 @@ def testSoftmaxOutput():

sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0,
use_ignore=False)

ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
args_grad={'x': grad_x})
args_grad=None)
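# args_grad=None: no gradient buffers are attached, since this executor is only
# used for the inference-mode (is_train=False) forward pass.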
ex.forward(is_train=False)
softmax_out = ex.outputs[0][0].asnumpy()
expected_softmax_out = (1 / SMALL_Y) * mx.nd.ones((SMALL_Y)).asnumpy()
assert np.isclose(softmax_out, expected_softmax_out).all()

ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
args_grad={'x': grad_x})
ex.forward(is_train=True)
softmax_out = ex.outputs[0][0].asnumpy()
expected_softmax_out = (1/SMALL_Y)*mx.nd.ones((SMALL_Y)).asnumpy()
expected_softmax_out = (1 / SMALL_Y) * mx.nd.ones((SMALL_Y)).asnumpy()
assert np.isclose(softmax_out, expected_softmax_out).all()

ex.backward(is_train=True)
@@ -782,8 +802,29 @@ def test_activation():
# in future, we could test if mean, var of output
# matches target output's mean, var
def test_batchnorm():
shape = (LARGE_X, SMALL_Y)
def get_np_mean_var(data, running_mean, running_var, eps, use_global_status=True):
if not use_global_status:
# train mode, calculate the real mean and var
mean = np.mean(data, axis=(0, 2, 3))
mean_broad = np.expand_dims(mean, axis=0)
mean_broad = np.expand_dims(mean_broad, axis=2)
mean_broad = np.expand_dims(mean_broad, axis=3)
mean_broad = np.broadcast_to(mean_broad, data.shape)
var = np.square(data - mean_broad)
var = np.mean(var, axis=(0, 2, 3))
else:
# inference mode, use running_mean and running_var instead
mean = np.full((data.shape[1],), running_mean)
var = np.full((data.shape[1],), running_var)

# calculate the inverse of standard variance
invstdvar = 1. / np.sqrt(var + eps)
return mean, invstdvar

# Use a 4D input here to cover both the MKL-DNN and the non-MKL-DNN BatchNorm paths
shape = (1, 2, LARGE_X, SMALL_Y)
axis = 1 # default
eps = 1e-3

nch = shape[axis]
data = mx.nd.ones(shape=shape)
@@ -793,8 +834,21 @@ def test_batchnorm():
bn_running_var = mx.nd.ones(nch)

output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
bn_running_mean, bn_running_var)
assert output.shape == shape
bn_running_mean, bn_running_var, output_mean_var=True)
assert output[0].shape == shape
mean, invstdvar = output[1], output[2]

np_mean, np_invstdvar = get_np_mean_var(data.asnumpy(), bn_running_mean.asnumpy(), bn_running_var.asnumpy(),
eps, use_global_status=True)
assert_almost_equal(mean.asnumpy(), np_mean)
assert_almost_equal(invstdvar.asnumpy(), np_invstdvar)
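# With bn_running_var all ones, eps = 1e-3 and use_global_status=True, the reference
# values are simply np_mean == running_mean and np_invstdvar == 1/sqrt(1 + 1e-3),
# i.e. about 0.9995 in every channel.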


def test_elemwise_add():
a = nd.ones(shape=(LARGE_X, SMALL_Y))
b = nd.ones(shape=(LARGE_X, SMALL_Y))
res = nd.elemwise_add(a, b)
assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]


def test_add():
@@ -944,19 +998,25 @@ def test_reshape_like():


def test_flatten():
a = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y).reshape((LARGE_X//2, 2, SMALL_Y))
b = nd.flatten(a)
assert b[-1][-1] == (LARGE_X-1)
assert b[-1][0] == (LARGE_X-2)
assert b.shape == (LARGE_X//2, SMALL_Y*2)
test_dtypes = [np.float32, np.int64]
for dtype in test_dtypes:
a = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y, dtype=dtype).reshape((LARGE_X//2, 2, SMALL_Y))
b = nd.flatten(a)
# The exact-value asserts were removed because int64 and float32 differ in precision:
# once LARGE_X is this large, float32 cannot represent LARGE_X-1 and LARGE_X-2 exactly,
# so only the shape and the last element are checked.
assert b.shape == (LARGE_X//2, SMALL_Y*2)
assert_almost_equal(b[-1,-1].asnumpy(), a[-1,-1,-1].asnumpy(), rtol=1e-8)
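# For example, assuming LARGE_X is on the order of 10**8, float32 represents integers
# exactly only up to 2**24, so LARGE_X-1 and LARGE_X-2 both round to the same float32
# value and an exact-value assert on them would fail.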


def test_concat():
a = nd.array(np.ones((SMALL_Y, LARGE_X)))
b = nd.array(np.zeros((SMALL_Y, LARGE_X)))
c = nd.concat(a, b, dim=0)
assert c.shape == (b.shape[0]*2, LARGE_X)

for axis in [0, 1]:
c = nd.concat(a, b, dim=axis)
c.wait_to_read()
assert c.shape[axis] == b.shape[axis] * 2
assert c.shape[1-axis] == b.shape[1-axis]

def test_stack():
a = nd.array(np.ones((SMALL_Y, LARGE_X)))
