Add large tensor nightly tests for MKL-DNN operators #16184
```diff
@@ -212,8 +212,15 @@ def test_dot():
 def test_FullyConnected():
     a = nd.ones(shape=(LARGE_X, SMALL_Y))
     b = nd.ones(shape=(SMALL_Y, SMALL_Y))
-    res = nd.FullyConnected(a, b, num_hidden=b.shape[1], no_bias=True)
-    assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1]
+    c = nd.ones(shape=(b.shape[0],))
+
+    # w/o bias
+    res = nd.FullyConnected(a, b, num_hidden=b.shape[0], no_bias=True)
+    assert np.sum(res[-1].asnumpy() == a.shape[1]) == b.shape[0]
+
+    # w/ bias
+    res = nd.FullyConnected(a, b, c, num_hidden=b.shape[0], no_bias=False)
+    assert np.sum(res[-1].asnumpy() == a.shape[1] + 1) == b.shape[0]


 def test_broadcast():
```
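For context on the numbers the new assertions expect: `FullyConnected` computes `data @ weight.T` (plus the bias when `no_bias=False`), so with all-ones inputs every output element equals the reduced dimension `a.shape[1]`, and an all-ones bias raises that by exactly 1. A minimal numpy sketch of the same arithmetic, using small stand-in sizes instead of `LARGE_X`/`SMALL_Y`:

```python
import numpy as np

X, Y = 7, 5                 # small stand-ins for LARGE_X, SMALL_Y
a = np.ones((X, Y))         # data
w = np.ones((Y, Y))         # weight, num_hidden == w.shape[0]
c = np.ones((Y,))           # bias

# w/o bias: each output element is a sum of Y ones == a.shape[1]
res = a @ w.T
assert np.sum(res[-1] == a.shape[1]) == w.shape[0]

# w/ bias: the all-ones bias shifts every element up by 1
res = a @ w.T + c
assert np.sum(res[-1] == a.shape[1] + 1) == w.shape[0]
```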
```diff
@@ -272,6 +279,7 @@ def test_slice_assign():
 def test_expand_dims():
     a = nd.ones(shape=(LARGE_X, SMALL_Y))
     res = nd.expand_dims(a, axis=1)
+    res.wait_to_read()
     assert a[0][0][0] == 1
     assert res.shape == (a.shape[0], 1, a.shape[1])
```
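The added `res.wait_to_read()` is what makes this a real large-tensor check: MXNet executes NDArray operations asynchronously, and output shapes are inferred up front, so the shape assertion alone could pass before the `expand_dims` kernel ever runs on the big input. Blocking on the result forces the operator (MKL-DNN or otherwise) to actually execute. A small sketch of the pattern:

```python
import mxnet as mx
from mxnet import nd

a = nd.ones(shape=(1000, 50))     # small stand-in for (LARGE_X, SMALL_Y)
res = nd.expand_dims(a, axis=1)   # enqueued asynchronously on the engine
res.wait_to_read()                # block until the kernel has really run
assert res.shape == (1000, 1, 50)
```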
```diff
@@ -401,10 +409,14 @@ def test_unravel_index():


 def test_transpose():
-    b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
-    t = b.T
-    assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]
-    assert t.shape == (SMALL_Y, LARGE_X)
+    test_dtypes = [np.float32, np.int64]
+    for dtype in test_dtypes:
+        b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y, dtype=dtype)
+        t = b.T
+        assert t.shape == (SMALL_Y, LARGE_X)
+        ref_out = np.transpose(b.asnumpy())
+        assert_almost_equal(t.asnumpy(), ref_out, rtol=1e-10)


 def test_swapaxes():
```
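The rewrite swaps the hand-computed check for a comparison against `np.transpose`, and runs it for both `np.float32` and `np.int64`. The old assert relied on `create_2d_tensor` filling row `i` with the value `i`, so the last column of the transpose is all `LARGE_X - 1`. A small numpy stand-in of that logic (the real helper is defined elsewhere in the test file; this mirrors only the behavior the assert depends on):

```python
import numpy as np

def create_2d_tensor(rows, columns, dtype=np.float32):
    # stand-in: row i holds the value i in every column
    return np.broadcast_to(np.arange(rows, dtype=dtype)[:, None],
                           (rows, columns))

b = create_2d_tensor(rows=8, columns=3)
t = b.T
assert t.shape == (3, 8)
assert np.sum(t[:, -1] == 8 - 1) == b.shape[1]   # last column: all rows-1
assert np.allclose(t, np.transpose(b))
```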
```diff
@@ -423,9 +435,10 @@ def test_flip():

 def test_softmax():
     input_data = mx.nd.ones((SMALL_Y, LARGE_X))
-    true_output = np.full((SMALL_Y, LARGE_X), (1 / SMALL_Y))
-    output = nd.softmax(input_data, axis=0)
-    assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)
+    for axis in [0, 1]:
+        true_output = np.full((SMALL_Y, LARGE_X), (1 / input_data.shape[axis]))
+        output = nd.softmax(input_data, axis=axis)
+        assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)


 def test_argsort():
```
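The loop is valid because softmax of a constant tensor is uniform along the reduced axis: when every entry is equal, e^{x_i} / Σ_j e^{x_j} = 1/n, with n the length of that axis. A quick numpy check of the identity:

```python
import numpy as np

def softmax(x, axis):
    # numerically stable softmax along `axis`
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.ones((3, 4))
for axis in [0, 1]:
    # constant input -> uniform 1/n along the reduced axis
    assert np.allclose(softmax(x, axis), 1.0 / x.shape[axis])
```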
```diff
@@ -619,12 +632,19 @@ def testSoftmaxOutput():

     sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0,
                                use_ignore=False)

     ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
-                  args_grad={'x': grad_x})
+                  args_grad=None)
+    ex.forward(is_train=False)
+    softmax_out = ex.outputs[0][0].asnumpy()
+    expected_softmax_out = (1 / SMALL_Y) * mx.nd.ones((SMALL_Y)).asnumpy()
+    assert np.isclose(softmax_out, expected_softmax_out).all()
+
+    ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
+                  args_grad={'x': grad_x})
     ex.forward(is_train=True)
     softmax_out = ex.outputs[0][0].asnumpy()
-    expected_softmax_out = (1/SMALL_Y)*mx.nd.ones((SMALL_Y)).asnumpy()
+    expected_softmax_out = (1 / SMALL_Y) * mx.nd.ones((SMALL_Y)).asnumpy()
     assert np.isclose(softmax_out, expected_softmax_out).all()

     ex.backward(is_train=True)
```
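The change splits the test into two executor binds: one with `args_grad=None` and `forward(is_train=False)`, which allocates no gradient buffers and exercises the pure inference path (presumably covering the MKL-DNN inference kernel), and the original one with gradient buffers so `backward` still runs. A tiny stand-in of the two-bind pattern, with small shapes and `default_context()` replaced by `mx.cpu()`:

```python
import mxnet as mx
import numpy as np

x = mx.sym.Variable('x')
label = mx.sym.Variable('label')
sym = mx.sym.SoftmaxOutput(data=x, label=label)

x_nd = mx.nd.ones((2, 5))
label_nd = mx.nd.zeros((2,))

# inference-only bind: no gradient buffers, forward(is_train=False)
ex = sym.bind(ctx=mx.cpu(), args={'x': x_nd, 'label': label_nd},
              args_grad=None)
out = ex.forward(is_train=False)[0].asnumpy()
assert np.allclose(out, 1.0 / 5)       # uniform softmax over 5 classes

# training bind: gradient buffers supplied so backward() can run
grad_x = mx.nd.zeros((2, 5))
ex = sym.bind(ctx=mx.cpu(), args={'x': x_nd, 'label': label_nd},
              args_grad={'x': grad_x})
ex.forward(is_train=True)
ex.backward()                          # SoftmaxOutput is a loss: no head grad
```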
```diff
@@ -782,8 +802,29 @@ def test_activation():
 # in future, we could test if mean, var of output
 # matches target output's mean, var
 def test_batchnorm():
-    shape = (LARGE_X, SMALL_Y)
+    def get_np_mean_var(data, running_mean, running_var, eps, use_global_status=True):
+        if not use_global_status:
+            # train mode, calculate the real mean and var
+            mean = np.mean(data, axis=(0, 2, 3))
+            mean_broad = np.expand_dims(mean, axis=0)
+            mean_broad = np.expand_dims(mean_broad, axis=2)
+            mean_broad = np.expand_dims(mean_broad, axis=3)
+            mean_broad = np.broadcast_to(mean_broad, data.shape)
+            var = np.square(data - mean_broad)
+            var = np.mean(var, axis=(0, 2, 3))
+        else:
+            # inference mode, use running_mean and running_var instead
+            mean = np.full((data.shape[1],), running_mean)
+            var = np.full((data.shape[1],), running_var)
+
+        # calculate the inverse of standard variance
+        invstdvar = 1. / np.sqrt(var + eps)
+        return mean, invstdvar
+
+    # Here use 4D input to cover mkldnn BN and non-mkldnn BN
+    shape = (1, 2, LARGE_X, SMALL_Y)
     axis = 1  # default
+    eps = 1e-3

     nch = shape[axis]
     data = mx.nd.ones(shape=shape)
```
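A note on what the helper returns: despite the name, the second statistic being checked is not the variance but the inverse standard deviation, 1/√(var + eps), which is what the test compares against `BatchNorm`'s `output_mean_var=True` outputs below. In inference mode (`use_global_status=True`) the per-channel statistics are simply the running mean and running variance taken as-is. A minimal numpy check of the inference branch (running mean assumed zero, mirroring the usual `bn_running_mean` initialization; running var of ones matches `bn_running_var = mx.nd.ones(nch)` below):

```python
import numpy as np

eps = 1e-3
nch = 2

running_mean = np.zeros(nch)   # assumed init, mirroring bn_running_mean
running_var = np.ones(nch)     # matches bn_running_var = mx.nd.ones(nch)

mean = running_mean                             # taken as-is in inference
invstdvar = 1.0 / np.sqrt(running_var + eps)    # inverse std, not variance

assert np.allclose(invstdvar, 1.0 / np.sqrt(1.0 + eps))
```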
```diff
@@ -793,8 +834,21 @@ def test_batchnorm():
     bn_running_var = mx.nd.ones(nch)

     output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
-                             bn_running_mean, bn_running_var)
-    assert output.shape == shape
+                             bn_running_mean, bn_running_var, output_mean_var=True)
+    assert output[0].shape == shape
+    mean, invstdvar = output[1], output[2]
+
+    np_mean, np_invstdvar = get_np_mean_var(data.asnumpy(), bn_running_mean.asnumpy(),
+                                            bn_running_var.asnumpy(), eps,
+                                            use_global_status=True)
+    assert_almost_equal(mean.asnumpy(), np_mean)
+    assert_almost_equal(invstdvar.asnumpy(), np_invstdvar)
+
+
+def test_elemwise_add():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.ones(shape=(LARGE_X, SMALL_Y))
+    res = nd.elemwise_add(a, b)
+    assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]


 def test_add():
```
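The new `test_elemwise_add` follows the assertion pattern used throughout this file: slice a single row (`res[-1]`) before calling `asnumpy()`, so only `SMALL_Y` values are copied to the host even though the tensor itself has `LARGE_X` rows. A small numpy illustration of the check:

```python
import numpy as np

X, Y = 1000, 50        # small stand-ins for LARGE_X, SMALL_Y
a = np.ones((X, Y))
b = np.ones((X, Y))
res = a + b            # elementwise add: every entry becomes 2
# assert only on the last row to keep the host-side copy small
assert np.sum(res[-1] == 2) == a.shape[1]
```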
```diff
@@ -944,19 +998,25 @@ def test_reshape_like():


 def test_flatten():
-    a = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y).reshape((LARGE_X//2, 2, SMALL_Y))
-    b = nd.flatten(a)
-    assert b[-1][-1] == (LARGE_X-1)
-    assert b[-1][0] == (LARGE_X-2)
-    assert b.shape == (LARGE_X//2, SMALL_Y*2)
+    test_dtypes = [np.float32, np.int64]
+    for dtype in test_dtypes:
+        a = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y, dtype=dtype).reshape((LARGE_X//2, 2, SMALL_Y))
+        b = nd.flatten(a)
+        # Here we removed the value asserts due to different precision of `int64` and `float32`.
+        # For `float32`, it will lose some precision when `LARGE_X` is too large, that is `LARGE_X-1`
+        # and `LARGE_X-2` can not represent the accurate value in the current situation.
+        assert b.shape == (LARGE_X//2, SMALL_Y*2)
+        assert_almost_equal(b[-1,-1].asnumpy(), a[-1,-1,-1].asnumpy(), rtol=1e-8)


 def test_concat():
     a = nd.array(np.ones((SMALL_Y, LARGE_X)))
     b = nd.array(np.zeros((SMALL_Y, LARGE_X)))
-    c = nd.concat(a, b, dim=0)
-    assert c.shape == (b.shape[0]*2, LARGE_X)
+    for axis in [0, 1]:
+        c = nd.concat(a, b, dim=axis)
+        c.wait_to_read()
+        assert c.shape[axis] == b.shape[axis] * 2
+        assert c.shape[1-axis] == b.shape[1-axis]


 def test_stack():
     a = nd.array(np.ones((SMALL_Y, LARGE_X)))
```

Review thread on the removed `test_flatten` value asserts:

> **Reviewer:** why are we removing these asserts?
>
> **Author:** It is related to the different precision of `int64` and `float32`.

Review thread on the new `test_flatten` shape assert:

> **Reviewer:** Can we also test one of the values inside tensor `b`?
>
> **Author:** Done
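The precision issue called out in the code comment and the thread above comes from float32's 24-bit significand: integers above 2^24 = 16777216 are not all exactly representable, so values on the order of `LARGE_X - 1` (assuming `LARGE_X` is 10^8, the value these nightly tests conventionally use) round to a neighboring representable float and an exact equality check fails, while `int64` stays exact. A quick numpy demonstration:

```python
import numpy as np

# float32 is exact only up to 2**24
assert np.float32(2**24) == 2**24
assert np.float32(2**24 + 1) == 2**24        # rounds: 2**24 + 1 not representable

# near 1e8 the spacing between adjacent float32 values is 8,
# so LARGE_X - 1 (assuming LARGE_X = 100000000) rounds up to LARGE_X
assert np.float32(100000000 - 1) == 100000000
assert np.int64(100000000 - 1) == 99999999   # int64 keeps the value exact
```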
Review thread on the `test_FullyConnected` change:

> **Reviewer:** what's the intuition behind adding this?
>
> **Author:** Just want to cover `w/ bias` and `w/o bias`.