Skip to content

Commit

Permalink
Add more gluon computation on MKLDNN with memory operation(slice, res…
Browse files Browse the repository at this point in the history
…hape etc.) (apache#10764)
  • Loading branch information
juliusshufan authored and zheng-da committed Jun 28, 2018
1 parent 6c4d663 commit be47024
Showing 1 changed file with 111 additions and 4 deletions.
115 changes: 111 additions & 4 deletions tests/python/mkl/test_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@
MKL-DNN related test cases
"""

import logging
import os
from sys import platform
import mxnet as mx
import numpy as np
import sys,os,logging
from mxnet import gluon
from mxnet.gluon import nn
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, '../unittest/'))
from common import setup_module, with_seed
from nose.tools import raises
from mxnet.test_utils import assert_almost_equal


Expand All @@ -35,7 +40,7 @@ def test_mkldnn_install():
"""
logging.basicConfig(level=logging.INFO)

if not platform.startswith('linux'):
if not sys.platform.startswith('linux'):
logging.info("Bypass mkldnn install test for non-Linux OS")
return

Expand Down Expand Up @@ -144,5 +149,107 @@ def test_mkldnn_ndarray_slice():
# trigger computation on ndarray slice
assert_almost_equal(y[0].asnumpy()[0, 0, 0], 0.3376348)

@with_seed()
def test_reshape_before_conv():
"""
This test will test gluon Conv2d computation on mkldnn with ndarray reshape
"""
class Net(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Net, self).__init__(**kwargs)
with self.name_scope():
self.conv0 = nn.Conv2D(10, (3, 3))
self.conv1 = nn.Conv2D(5, (3, 3))

def hybrid_forward(self, F, x):
x_reshape = x.reshape((0, 0, 20, 5))
y = self.conv0(x_reshape)
y_reshape = y.reshape((0, 0, 9, 6))
out = self.conv1(y_reshape)
return out
x = mx.nd.random.uniform(shape=(2, 4, 10, 10))
x.attach_grad()
net = Net()
net.collect_params().initialize()
with mx.autograd.record():
out1 = net(x)
out1.backward()
dx1 = x.grad
net.hybridize()
with mx.autograd.record():
out2 = net(x)
out2.backward()
mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6)
mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)


@with_seed()
def test_slice_before_conv():
"""
This test will test gluon Conv2d computation on mkldnn with ndarray slice
"""
class Net(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Net, self).__init__(**kwargs)
with self.name_scope():
self.conv0 = nn.Conv2D(4, (3, 3))
self.conv1 = nn.Conv2D(4, (3, 3))

def hybrid_forward(self, F, x):
x_slice = x.slice(begin=(0, 0, 0, 0), end=(2, 4, 10, 10))
y = self.conv0(x_slice)
y_slice = y.slice(begin=(1, 0, 2, 2), end=(2, 1, 7, 7))
out = self.conv1(y_slice)
return out
x = mx.nd.random.uniform(shape=(2, 10, 10, 10))
x.attach_grad()
net = Net()
net.collect_params().initialize()
with mx.autograd.record():
out1 = net(x)
out1.backward()
dx1 = x.grad
net.hybridize()
with mx.autograd.record():
out2 = net(x)
out2.backward()
mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6)
mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)


@with_seed()
def test_slice_reshape_before_conv():
"""
This test will test gluon Conv2d computation on mkldnn with ndarray reshape and slice
"""
class Net(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Net, self).__init__(**kwargs)
with self.name_scope():
self.conv0 = nn.Conv2D(4, (3, 3))
self.conv1 = nn.Conv2D(4, (3, 3))

def hybrid_forward(self, F, x):
x_slice = x.slice(begin=(0, 0, 0, 0), end=(2, 4, 8, 9))
y = self.conv0(x_slice)
y_reshape = y.reshape((0, 0, 14, 3))
out = self.conv1(y_reshape)
return out
x = mx.nd.random.uniform(shape=(2, 10, 10, 10))
x.attach_grad()
net = Net()
net.collect_params().initialize()
with mx.autograd.record():
out1 = net(x)
out1.backward()
dx1 = x.grad
net.hybridize()
with mx.autograd.record():
out2 = net(x)
out2.backward()
mx.test_utils.assert_almost_equal(dx1.asnumpy(), x.grad.asnumpy(), rtol=1e-5, atol=1e-6)
mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)


if __name__ == '__main__':
test_mkldnn_install()

0 comments on commit be47024

Please sign in to comment.