diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 4501c3b6261c..dc9e9141a359 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -19,10 +19,15 @@ MKL-DNN related test cases """ -import logging -import os -from sys import platform +import mxnet as mx import numpy as np +import sys,os,logging +from mxnet import gluon +from mxnet.gluon import nn +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '../unittest/')) +from common import setup_module, with_seed +from nose.tools import raises from mxnet.test_utils import assert_almost_equal @@ -35,7 +40,7 @@ def test_mkldnn_install(): """ logging.basicConfig(level=logging.INFO) - if not platform.startswith('linux'): + if not sys.platform.startswith('linux'): logging.info("Bypass mkldnn install test for non-Linux OS") return @@ -144,5 +149,107 @@ def test_mkldnn_ndarray_slice(): # trigger computation on ndarray slice assert_almost_equal(y[0].asnumpy()[0, 0, 0], 0.3376348) +@with_seed() +def test_reshape_before_conv(): + """ + This test will test gluon Conv2d computation on mkldnn with ndarray reshape + """ + class Net(gluon.HybridBlock): + def __init__(self, **kwargs): + super(Net, self).__init__(**kwargs) + with self.name_scope(): + self.conv0 = nn.Conv2D(10, (3, 3)) + self.conv1 = nn.Conv2D(5, (3, 3)) + + def hybrid_forward(self, F, x): + x_reshape = x.reshape((0, 0, 20, 5)) + y = self.conv0(x_reshape) + y_reshape = y.reshape((0, 0, 9, 6)) + out = self.conv1(y_reshape) + return out + x = mx.nd.random.uniform(shape=(2, 4, 10, 10)) + x.attach_grad() + net = Net() + net.collect_params().initialize() + with mx.autograd.record(): + out1 = net(x) + out1.backward() + dx1 = x.grad + net.hybridize() + with mx.autograd.record(): + out2 = net(x) + out2.backward() + mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6) + mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6) + + +@with_seed() +def test_slice_before_conv(): + """ + This test will test gluon Conv2d computation on mkldnn with ndarray slice + """ + class Net(gluon.HybridBlock): + def __init__(self, **kwargs): + super(Net, self).__init__(**kwargs) + with self.name_scope(): + self.conv0 = nn.Conv2D(4, (3, 3)) + self.conv1 = nn.Conv2D(4, (3, 3)) + + def hybrid_forward(self, F, x): + x_slice = x.slice(begin=(0, 0, 0, 0), end=(2, 4, 10, 10)) + y = self.conv0(x_slice) + y_slice = y.slice(begin=(1, 0, 2, 2), end=(2, 1, 7, 7)) + out = self.conv1(y_slice) + return out + x = mx.nd.random.uniform(shape=(2, 10, 10, 10)) + x.attach_grad() + net = Net() + net.collect_params().initialize() + with mx.autograd.record(): + out1 = net(x) + out1.backward() + dx1 = x.grad + net.hybridize() + with mx.autograd.record(): + out2 = net(x) + out2.backward() + mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6) + mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6) + + +@with_seed() +def test_slice_reshape_before_conv(): + """ + This test will test gluon Conv2d computation on mkldnn with ndarray reshape and slice + """ + class Net(gluon.HybridBlock): + def __init__(self, **kwargs): + super(Net, self).__init__(**kwargs) + with self.name_scope(): + self.conv0 = nn.Conv2D(4, (3, 3)) + self.conv1 = nn.Conv2D(4, (3, 3)) + + def hybrid_forward(self, F, x): + x_slice = x.slice(begin=(0, 0, 0, 0), end=(2, 4, 8, 9)) + y = self.conv0(x_slice) + y_reshape = y.reshape((0, 0, 14, 3)) + out = self.conv1(y_reshape) + return out + x = mx.nd.random.uniform(shape=(2, 10, 10, 10)) + x.attach_grad() + net = Net() + net.collect_params().initialize() + with mx.autograd.record(): + out1 = net(x) + out1.backward() + dx1 = x.grad + net.hybridize() + with mx.autograd.record(): + out2 = net(x) + out2.backward() + mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6) + mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6) + + if __name__ == '__main__': test_mkldnn_install()