diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 0faa668caf97..2ffeabc11ae3 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -103,6 +103,57 @@ DMLC_REGISTER_PARAMETER(StackParam);
 DMLC_REGISTER_PARAMETER(SqueezeParam);
 DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
 
+#if MXNET_USE_MKLDNN == 1
+void MKLDNNReshape(const NDArray &in_data, const NDArray &out_data) {
+  MSHADOW_TYPE_SWITCH(in_data.dtype(), DType, {
+    auto this_mem = in_data.GetMKLDNNData();
+    auto out_dptr = out_data.data().dptr<DType>();
+    mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
+    mkldnn::memory::desc this_desc = this_pd.desc();
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, this_pd.get_engine());
+    auto temp_mem = mkldnn::memory(pd, out_dptr);
+    MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*this_mem, temp_mem));
+    MKLDNNStream::Get()->Submit();
+
+    // Remove out_data's mkl_mem_ and store the data in the default format
+    const_cast<NDArray &>(out_data).InvalidateMKLDNNData();
+  });
+}
+
+static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  // If the input may be in MKLDNN format and MKLDNN supports
+  // its data type and shape, convert it to the output format
+  // and shape.
+  if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape()) && req[0] != kAddTo) {
+    MKLDNNReshape(inputs[0], outputs[0]);
+    return;
+  }
+  FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
+                                      const int dev_mask,
+                                      DispatchMode* dispatch_mode,
+                                      std::vector<int>* in_attrs,
+                                      std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(Reshape)
 .add_alias("reshape")
 .describe(R"code(Reshapes the input array.
@@ -174,6 +225,14 @@ If the argument `reverse` is set to 1, then the special values are inferred from
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
+.set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#else
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs) {
     return std::vector<std::pair<int, int> >{{0, 0}};
@@ -182,6 +241,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
   [](const NodeAttrs& attrs){
     return std::vector<bool>{true};
   })
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Input data to reshape.")
 .add_arguments(ReshapeParam::__FIELDS__());
 
@@ -210,6 +270,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
 #endif
 }
 
+#if MXNET_USE_MKLDNN == 1
 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
                                       const int dev_mask,
                                       DispatchMode* dispatch_mode,
@@ -217,17 +278,10 @@ static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
                                       std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
-  bool ret = ElemwiseStorageType<1, 1, false, false, false>(attrs, dev_mask, dispatch_mode,
-                                                            in_attrs, out_attrs);
-#if MXNET_USE_MKLDNN == 1
-  if (dev_mask == mshadow::cpu::kDevMask
-      && in_attrs->at(0) == kDefaultStorage
-      && out_attrs->at(0) == kDefaultStorage) {
-    *dispatch_mode = DispatchMode::kFComputeEx;
-  }
-#endif
-  return ret;
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
 }
+#endif
 
 NNVM_REGISTER_OP(Flatten)
 .add_alias("flatten")
@@ -261,7 +315,9 @@ Example::
 .set_num_outputs(1)
 .set_attr<nnvm::FInferShape>("FInferShape", FlattenShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+#if MXNET_USE_MKLDNN == 1
 .set_attr<FInferStorageType>("FInferStorageType", FlattenStorageType)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_copy" })
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", FlattenEx)
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index d9d3abfc3ced..aae66bd9a60e 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -97,6 +97,37 @@ def __getitem__(self, key):
         assert_almost_equal(y[0, 0, 0, 0], 0.016711406)
         break
 
+@with_seed()
+def test_mkldnn_reshape():
+    def test_reshape_after_conv(dst_shape):
+        shape = (1, 1, 4, 4)
+        data = mx.symbol.Variable('data')
+        conv = mx.symbol.Convolution(data=data, num_filter=16, kernel=(1, 1), pad=(0, 0), stride=(1, 1))
+        res = mx.symbol.reshape(data=conv, shape=dst_shape)
+        exe = res.simple_bind(mx.cpu(), data=shape, grad_req='null')
+
+        val1 = np.random.uniform(-1, 1, (4, 4))
+        val2 = np.random.uniform(-1, 1, (1, 1, 1, 1))
+        val3 = np.random.uniform(-1, 1, (1,))
+
+        exe.arg_arrays[0][:] = val1
+        exe.arg_arrays[1][:] = val2
+        exe.arg_arrays[2][:] = val3
+        outputs = exe.forward(is_train=False)[0].asnumpy()
+
+        conv_exe = conv.simple_bind(mx.cpu(), data=shape, grad_req='null')
+        conv_exe.arg_arrays[0][:] = val1
+        conv_exe.arg_arrays[1][:] = val2
+        conv_exe.arg_arrays[2][:] = val3
+        data_npy = conv_exe.forward(is_train=False)[0].asnumpy()
+        assert_almost_equal(outputs, data_npy.reshape(dst_shape))
+
+
+    # Test MKLDNN reshape with a range of target shapes
+    test_cases = [(256,), (16, 16), (4, 4, 16), (4, 4, 4, 4)]
+    for test_case in test_cases:
+        test_reshape_after_conv(test_case)
+
 
 @with_seed()
 def test_reshape_before_conv():
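
For reference, below is a minimal usage sketch of the path the new test exercises, written against the imperative NDArray API rather than `simple_bind`. It is not part of the patch; it assumes an MKLDNN-enabled CPU build of MXNet (`MXNET_USE_MKLDNN=1`), and the shapes are illustrative. On such a build the `reshape` call dispatches to the new `ReshapeComputeExCPU`, which reorders the convolution's MKLDNN-layout output into the default format while reshaping; on other builds the same code runs through the ordinary `FCompute` path.

```python
import mxnet as mx
import numpy as np

ctx = mx.cpu()
data = mx.nd.random.uniform(-1, 1, shape=(1, 1, 4, 4), ctx=ctx)
weight = mx.nd.random.uniform(-1, 1, shape=(16, 1, 1, 1), ctx=ctx)
bias = mx.nd.zeros((16,), ctx=ctx)

# A 1x1 convolution typically leaves its output in an MKLDNN-internal
# layout; reshaping it exercises the reorder in MKLDNNReshape.
out = mx.nd.Convolution(data=data, weight=weight, bias=bias, num_filter=16,
                        kernel=(1, 1), pad=(0, 0), stride=(1, 1))
reshaped = mx.nd.reshape(out, shape=(16, 16))  # (1, 16, 4, 4) -> (16, 16)

# The result must match a plain NumPy reshape of the conv output.
np.testing.assert_almost_equal(reshaped.asnumpy(),
                               out.asnumpy().reshape(16, 16))
```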