Add reshape op supported by MKL-DNN (apache#12980)
* Add reshape op supported by MKL-DNN

* fix build issue

* fix lint

* fix lint

* fix lint

* fix lint

* fix lint

* fix lint

* fix white space

* add unit test

* merge if blocks
huangzhiyuan authored and anirudh2290 committed Dec 13, 2018
1 parent 090f222 commit c4a619c
Showing 2 changed files with 97 additions and 10 deletions.
76 changes: 66 additions & 10 deletions src/operator/tensor/matrix_op.cc
@@ -103,6 +103,57 @@ DMLC_REGISTER_PARAMETER(StackParam);
DMLC_REGISTER_PARAMETER(SqueezeParam);
DMLC_REGISTER_PARAMETER(DepthToSpaceParam);

#if MXNET_USE_MKLDNN == 1
void MKLDNNReshape(const NDArray &in_data, const NDArray &out_data) {
  MSHADOW_TYPE_SWITCH(in_data.dtype(), DType, {
    auto this_mem = in_data.GetMKLDNNData();
    auto out_dptr = out_data.data().dptr<DType>();
    mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
    mkldnn::memory::desc this_desc = this_pd.desc();
    mkldnn::memory::dims dims(this_desc.data.dims,
                              this_desc.data.dims + this_desc.data.ndims);
    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
    // Reorder the (possibly MKL-DNN-formatted) input into the output buffer,
    // using the default format for this number of dimensions.
    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
    mkldnn::memory::primitive_desc pd(data_md, this_pd.get_engine());
    auto temp_mem = mkldnn::memory(pd, out_dptr);
    MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*this_mem, temp_mem));
    MKLDNNStream::Get()->Submit();

    // Remove out_data's mkl_mem_ so the data is kept in the default format.
    const_cast<NDArray &>(out_data).InvalidateMKLDNNData();
  });
}

static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx,
                                const std::vector<NDArray>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  // If the input may be in MKL-DNN format and MKL-DNN supports its data type
  // and shape, convert it to the output format and shape; otherwise fall back
  // to the default identity compute.
  if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape()) && req[0] != kAddTo) {
    MKLDNNReshape(inputs[0], outputs[0]);
    return;
  }
  FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
}

inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
                                      const int dev_mask,
                                      DispatchMode* dispatch_mode,
                                      std::vector<int>* in_attrs,
                                      std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
                           out_attrs);
}
#endif

NNVM_REGISTER_OP(Reshape)
.add_alias("reshape")
.describe(R"code(Reshapes the input array.
@@ -174,6 +225,14 @@ If the argument `reverse` is set to 1, then the special values are inferred from
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#else
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                [](const NodeAttrs& attrs) {
                                  return std::vector<std::pair<int, int> >{{0, 0}};
@@ -182,6 +241,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
                                [](const NodeAttrs& attrs){
                                  return std::vector<bool>{true};
                                })
#endif
.add_argument("data", "NDArray-or-Symbol", "Input data to reshape.")
.add_arguments(ReshapeParam::__FIELDS__());

@@ -210,24 +270,18 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
#endif
}

+#if MXNET_USE_MKLDNN == 1
 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
                                       const int dev_mask,
                                       DispatchMode* dispatch_mode,
                                       std::vector<int> *in_attrs,
                                       std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
-  bool ret = ElemwiseStorageType<1, 1, false, false, false>(attrs, dev_mask, dispatch_mode,
-                                                            in_attrs, out_attrs);
-#if MXNET_USE_MKLDNN == 1
-  if (dev_mask == mshadow::cpu::kDevMask
-      && in_attrs->at(0) == kDefaultStorage
-      && out_attrs->at(0) == kDefaultStorage) {
-    *dispatch_mode = DispatchMode::kFComputeEx;
-  }
-#endif
-  return ret;
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
 }
+#endif

NNVM_REGISTER_OP(Flatten)
.add_alias("flatten")
@@ -261,7 +315,9 @@ Example::
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", FlattenShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", FlattenStorageType)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_copy" })
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FlattenEx)
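For context, here is a minimal sketch of how the new path gets exercised end to end. This is an illustration, not part of the commit; it assumes a CPU build with MXNET_USE_MKLDNN=1, and on other builds the same script simply runs through the FCompute fallback. Because the convolution output may be held in an MKL-DNN-internal layout, the reshape that consumes it dispatches to ReshapeComputeExCPU, which reorders the data into the default layout while writing the reshaped output; this is also why the MKL-DNN branch of the registration requests kTempSpace instead of declaring the op in-place.

import mxnet as mx
import numpy as np

# Reshape consumes a convolution output, which on an MKL-DNN build may be
# stored in an MKL-DNN-internal layout.
data = mx.symbol.Variable('data')
conv = mx.symbol.Convolution(data=data, num_filter=8, kernel=(3, 3))
net = mx.symbol.reshape(data=conv, shape=(0, -1))

exe = net.simple_bind(mx.cpu(), data=(1, 3, 8, 8), grad_req='null')
for arr in exe.arg_arrays:
    arr[:] = np.random.uniform(-1, 1, arr.shape)
out = exe.forward(is_train=False)[0]
print(out.shape)  # (1, 288): 8 filters x 6 x 6 spatial positions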
31 changes: 31 additions & 0 deletions tests/python/mkl/test_mkldnn.py
@@ -97,6 +97,37 @@ def __getitem__(self, key):
        assert_almost_equal(y[0, 0, 0, 0], 0.016711406)
        break

@with_seed()
def test_mkldnn_reshape():
    def test_reshape_after_conv(dst_shape):
        shape = (1, 1, 4, 4)
        data = mx.symbol.Variable('data')
        conv = mx.symbol.Convolution(data=data, num_filter=16, kernel=(1, 1), pad=(0, 0), stride=(1, 1))
        res = mx.symbol.reshape(data=conv, shape=dst_shape)
        exe = res.simple_bind(mx.cpu(), data=shape, grad_req='null')

        val1 = np.random.uniform(-1, 1, (4, 4))
        val2 = np.random.uniform(-1, 1, (1, 1, 1, 1))
        val3 = np.random.uniform(-1, 1, (1,))

        exe.arg_arrays[0][:] = val1
        exe.arg_arrays[1][:] = val2
        exe.arg_arrays[2][:] = val3
        outputs = exe.forward(is_train=False)[0].asnumpy()

        conv_exe = conv.simple_bind(mx.cpu(), data=shape, grad_req='null')
        conv_exe.arg_arrays[0][:] = val1
        conv_exe.arg_arrays[1][:] = val2
        conv_exe.arg_arrays[2][:] = val3
        data_npy = conv_exe.forward(is_train=False)[0].asnumpy()
        assert_almost_equal(outputs, data_npy.reshape(dst_shape))


    # Test MKL-DNN reshape across a range of target shapes.
    test_cases = [(256,), (16, 16), (4, 4, 16), (4, 4, 4, 4)]
    for test_case in test_cases:
        test_reshape_after_conv(test_case)


@with_seed()
def test_reshape_before_conv():
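As a quick local check, the new case can also be invoked as a plain function outside the test runner. A hypothetical snippet, assuming an MKL-DNN-enabled build and that the interpreter is started from tests/python/mkl so the test module and its helpers import cleanly:

# Hypothetical direct invocation; test_mkldnn and test_mkldnn_reshape are the
# module and test case added by this commit.
import test_mkldnn

test_mkldnn.test_mkldnn_reshape()
print("test_mkldnn_reshape passed")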
