diff --git a/docs/tutorials/mkldnn/operator_list.md b/docs/tutorials/mkldnn/operator_list.md
index 4958f8d9b602..0ef0f29f4cdc 100644
--- a/docs/tutorials/mkldnn/operator_list.md
+++ b/docs/tutorials/mkldnn/operator_list.md
@@ -44,6 +44,8 @@ To help users understanding MKL-DNN backend better, the following table summariz
 | **elemwise_add** | 1D-4D input | Y | Y | Y |
 | **Concat** | 1D-4D input | Y | Y | Y |
 | **slice** | 1D-4D input | N | Y | N |
+| **Reshape** | 1D-4D input | N | Y | N |
+| **Flatten** | 1D-4D input | N | Y | N |
 | **Quantization** | 1D-4D input | N | N | Y |
 | **Dequantization** | 1D-4D input | N | N | Y |
 | **Requantization** | 1D-4D input | N | N | Y |
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 5670983e6aa3..e01b7b14082b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -176,6 +176,7 @@ struct DeconvolutionParam;
 struct SoftmaxParam;
 struct SoftmaxOutputParam;
 struct TransposeParam;
+struct ReshapeParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
@@ -184,6 +185,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
+bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc
new file mode 100644
index 000000000000..fdc02f960009
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_flatten.cc
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_flatten.cc
+ * \brief Implement the flatten operator via the mkldnn reorder primitive
+ * \author Wuxun Zhang
+ */
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "mkldnn_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNFlattenFwd : public MKLDNNReshapeFwd {
+ public:
+  explicit MKLDNNFlattenFwd(const OpReqType &req,
+                            const NDArray &input,
+                            const NDArray &output)
+      : MKLDNNReshapeFwd(req, input, output) {}
+};
+
+static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req,
+                                           const NDArray &input,
+                                           const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature, MKLDNNFlattenFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNFlattenFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(req);
+  key.AddSign(input);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNFlattenFwd fwd(req, input, output);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output) {
+  if (req == kNullOp) return;
+  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
+
+  auto &fwd = GetFlattenForward(req, input, output);
+  auto ws_size = fwd.GetWorkspaceSize();
+  void* ws_ptr = nullptr;
+  if (ws_size) {
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
+        .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
+    ws_ptr = reinterpret_cast<void*>(ws.dptr_);
+  }
+
+  fwd.Execute(input, output, ws_ptr);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 2699a02f9b9b..502abff6231b 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -119,12 +119,17 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const OpReqType &req,
                             const NDArray &output);
 
-void MKLDNNReshapeForward(const nnvm::NodeAttrs &attrs,
+void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
-                          const NDArray &data,
+                          const NDArray &input,
                           const OpReqType &req,
                           const NDArray &output);
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output);
 
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
new file mode 100644
index 000000000000..63e367b4dc7f
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_reshape-inl.h
+ * \brief Function definition of the mkldnn reshape operator
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+#include <vector>
+#include "mkldnn_base-inl.h"
+#include "../../tensor/matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNReshapeFwd {
+ protected:
+  std::shared_ptr<mkldnn::memory> data_;
+  std::shared_ptr<mkldnn::memory> out_;
+  std::shared_ptr<mkldnn::memory> temp_;
+  std::vector<mkldnn::primitive> prims_;
+  bool needInvalidateInput = false;
+
+ public:
+  MKLDNNReshapeFwd(const OpReqType &req,
+                   const NDArray &input,
+                   const NDArray &output);
+  int GetWorkspaceSize();
+  void SetNewMem(const NDArray &input,
+                 const NDArray &output,
+                 void* workspace = nullptr);
+  void Execute(const NDArray &input,
+               const NDArray &output,
+               void* workspace = nullptr);
+};
+
+typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
+MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
+                                    const OpReqType &req,
+                                    const NDArray &input,
+                                    const NDArray &output);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index 4f1d67a8ff9e..063c85dae39a 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -26,7 +26,7 @@
 
 #if MXNET_USE_MKLDNN == 1
 #include <mkldnn.hpp>
-#include "../../tensor/matrix_op-inl.h"
+#include "mkldnn_reshape-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -43,117 +43,106 @@ bool SupportMKLDNNReshape(const ReshapeParam &param,
   return true;
 }
 
-typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
-
-class MKLDNNReshapeForward {
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> out_;
-  std::shared_ptr<mkldnn::memory> temp_;
-  std::vector<mkldnn::primitive> prims_;
-
-  bool needInvalidateInput = false;
-
- public:
-  MKLDNNReshapeForward(const ReshapeParam &param,
-                       const OpReqType &req,
-                       const NDArray &input,
-                       const NDArray &output) {
-    auto engine = CpuEngine::Get()->get_engine();
-
-    // data_
-    auto in_mem = input.GetMKLDNNData();
-    auto in_pd = in_mem->get_primitive_desc();
-    data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
-
-    // temp_
-    auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
-    auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
-    auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
-    auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
-    auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
-    temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
-    // destination
-    out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
-    if (req == kWriteInplace) {
-      // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with
-      // default layout and copy from the temporal buffer back to output buffer which has the same
-      // address with input buffer.
-      // If the input has default layout, then nothing need to do.
-      if (input.IsMKLDNNData()) {
-        prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
-        prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy back
-        needInvalidateInput = true;
-      }
-    } else if (req == kWriteTo) {
-      if (input.IsMKLDNNData()) {
-        prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
-        prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy to the output buffer
-        needInvalidateInput = false;
-      } else {
-        prims_.push_back(mkldnn::reorder(*data_, *out_));    // copy directly from input to output
-        needInvalidateInput = false;
-      }
+MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
+                                   const NDArray &input,
+                                   const NDArray &output) {
+  auto engine = CpuEngine::Get()->get_engine();
+
+  // data_
+  auto in_mem = input.GetMKLDNNData();
+  auto in_pd = in_mem->get_primitive_desc();
+  data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
+
+  // temp_
+  auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
+  auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
+  auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
+  auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
+  auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
+  temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+
+  // destination
+  out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+
+  if (req == kWriteInplace) {
+    // If the input has an MKL-DNN internal layout, we need to reorder it to a temporary buffer
+    // with default layout, then copy from the temporary buffer back to the output buffer, which
+    // shares its address with the input buffer.
+    // If the input already has the default layout, nothing needs to be done.
+    if (input.IsMKLDNNData()) {
+      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy back
+      needInvalidateInput = true;
+    }
+  } else if (req == kWriteTo) {
+    if (input.IsMKLDNNData()) {
+      prims_.push_back(mkldnn::reorder(*data_, *temp_));  // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));   // copy to the output buffer
+      needInvalidateInput = false;
     } else {
-      LOG(FATAL) << "not supported req type: " << req;
+      prims_.push_back(mkldnn::reorder(*data_, *out_));   // copy directly from input to output
+      needInvalidateInput = false;
     }
+  } else {
+    LOG(FATAL) << "not supported req type: " << req;
   }
+}
 
-  int GetWorkspaceSize() {
-    return temp_ ? temp_->get_primitive_desc().get_size() : 0;
-  }
+int MKLDNNReshapeFwd::GetWorkspaceSize() {
+  return temp_ ? temp_->get_primitive_desc().get_size() : 0;
+}
 
-  void SetNewMem(const NDArray &input, const NDArray &output, void* workspace = nullptr) {
-    if (input.IsMKLDNNData()) {
-      this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
-    } else {
-      MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
-        this->data_->set_data_handle(input.data().dptr<DTYPE>());
-      })
-    }
+void MKLDNNReshapeFwd::SetNewMem(const NDArray &input,
+                                 const NDArray &output,
+                                 void* workspace) {
+  if (input.IsMKLDNNData()) {
+    this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
+  } else {
+    MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
+      this->data_->set_data_handle(input.data().dptr<DTYPE>());
+    })
+  }
 
-    if (output.IsMKLDNNData()) {
-      this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
-    } else {
-      MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
-        this->out_->set_data_handle(output.data().dptr<DTYPE>());
-      })
-    }
+  if (output.IsMKLDNNData()) {
+    this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
+  } else {
+    MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
+      this->out_->set_data_handle(output.data().dptr<DTYPE>());
+    })
+  }
 
-    if (workspace) {
-      this->temp_->set_data_handle(workspace);
-    }
+  if (workspace) {
+    this->temp_->set_data_handle(workspace);
   }
+}
 
-  void Execute(const NDArray &input,
-               const NDArray &output,
-               void* workspace = nullptr) {
-    // set memory handles
-    SetNewMem(input, output, workspace);
-    // register primitives
-    auto stream = MKLDNNStream::Get();
-    for (auto &v : this->prims_) {
-      stream->RegisterPrim(v);
-    }
-    stream->Submit();
-    // invalidate mkldnn memory in input
-    if (needInvalidateInput) {
-      const_cast<NDArray &>(input).InvalidateMKLDNNData();
-    }
+void MKLDNNReshapeFwd::Execute(const NDArray &input,
+                               const NDArray &output,
+                               void* workspace) {
+  // set memory handles
+  SetNewMem(input, output, workspace);
+  // register primitives
+  auto stream = MKLDNNStream::Get();
+  for (auto &v : this->prims_) {
+    stream->RegisterPrim(v);
   }
-};
+  stream->Submit();
+  // invalidate mkldnn memory in input
+  if (needInvalidateInput) {
+    const_cast<NDArray &>(input).InvalidateMKLDNNData();
+  }
+}
 
-static MKLDNNReshapeForward &GetReshapeForward(const ReshapeParam& param,
-                                               const OpReqType &req,
-                                               const NDArray &input,
-                                               const NDArray &output) {
+MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
+                                    const OpReqType &req,
+                                    const NDArray &input,
+                                    const NDArray &output) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNReshapeSignature,
-                                         MKLDNNReshapeForward, OpHash> fwds;
+                                         MKLDNNReshapeFwd, OpHash> fwds;
 #else
   static MX_THREAD_LOCAL std::unordered_map<MKLDNNReshapeSignature,
-                                            MKLDNNReshapeForward, OpHash> fwds;
+                                            MKLDNNReshapeFwd, OpHash> fwds;
 #endif
   MKLDNNReshapeSignature key(param);
   key.AddSign(req);
@@ -162,7 +151,7 @@ static MKLDNNReshapeForward &GetReshapeForward(const ReshapeParam& param,
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    MKLDNNReshapeForward fwd(param, req, input, output);
+    MKLDNNReshapeFwd fwd(req, input, output);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index b4abc9f5974a..c2bcb29193a7 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -111,12 +111,13 @@ static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
+  const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   // If inputs are supposed to be in MKLDNN format and
   // MKLDNN supports the data type and the shape, then convert
   // it to the output format and shape.
-  if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape())) {
+  if (SupportMKLDNNReshape(param, inputs[0])) {
     MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   }
@@ -233,12 +234,9 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
 #if MXNET_USE_MKLDNN == 1
-  if (inputs[0].IsMKLDNNData()) {
-    MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]);
-    // If the output is a special MKLDNN layout and the number of dimensions
-    // is larger than 2, we should use the default layout.
-    if (outputs[0].IsMKLDNNData() && inputs[0].shape().ndim() > 2)
-      const_cast<NDArray &>(outputs[0]).Reorder2Default();
+  auto data_ndim = inputs[0].shape().ndim();
+  if (data_ndim <= 4 && inputs[0].dtype() == mshadow::kFloat32) {
+    MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   } else {
     // This happens if inputs are supposed to be in MKLDNN format
@@ -252,10 +250,10 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
-                               const int dev_mask,
-                               DispatchMode* dispatch_mode,
-                               std::vector<int> *in_attrs,
-                               std::vector<int> *out_attrs) {
+                                      const int dev_mask,
+                                      DispatchMode* dispatch_mode,
+                                      std::vector<int> *in_attrs,
+                                      std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
   return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 064f783ec6c8..5b4f81d96065 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1129,6 +1129,18 @@ def test_pooling_full_2d_type(pool_type):
     test_pooling_full_2d_type('sum')
 
 
+@with_seed()
+def test_flatten_slice_after_conv():
+    data = mx.sym.Variable('conv_data')
+    conv = mx.symbol.Convolution(data=data, name='conv', num_filter=16, kernel=(3, 3), stride=(1, 1))
+    flatten = mx.symbol.flatten(data=conv)
+    slice_sym = mx.symbol.slice(data=flatten, begin=0, end=1)
+
+    ctx_list = [{'ctx': mx.gpu(0), 'conv_data': (2, 16, 16, 16), 'type_dict': {'conv_data': np.float32}},
+                {'ctx': mx.cpu(0), 'conv_data': (2, 16, 16, 16), 'type_dict': {'conv_data': np.float32}}]
+    check_consistency(slice_sym, ctx_list)
+
+
 @with_seed()
 def test_global_pooling():
     def test_1d_pooling(pool_type, p_value=2):
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 662edcfeb739..3e623b59977d 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -233,6 +233,26 @@ def hybrid_forward(self, F, x, *args, **kwargs):
     mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)
 
 
+@with_seed()
+def test_flatten_slice_after_conv():
+    data = mx.symbol.Variable('data')
+    weight = mx.symbol.Variable('weight')
+    bias = mx.symbol.Variable('bias')
+    conv1 = mx.symbol.Convolution(data=data, weight=weight, bias=bias, name='conv1', num_filter=64, kernel=(3, 3), stride=(1, 1))
+    flatten1 = mx.symbol.flatten(data=conv1)
+    slice1 = mx.symbol.slice(data=flatten1, begin=0, end=1)
+
+    shape = (2, 16, 16, 16)
+    val = np.random.rand(*shape).astype(np.float32)
+    exe = slice1.simple_bind(Context.default_ctx, data=shape)
+    exe.arg_arrays[0][:] = val
+    exe.arg_arrays[1][:] = np.random.normal(size=exe.arg_arrays[1].shape)
+    exe.arg_arrays[2][:] = np.random.normal(size=exe.arg_arrays[2].shape)
+    p = exe.forward(is_train=False)
+    p[0].wait_to_read()
+    print(p[0])
+
+
 def test_mkldnn_sum_inplace_with_cpu_layout():
 
     x_shape = (32, 3, 224, 224)
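
Note for reviewers: the unit tests above only exercise the new Flatten path. A minimal sketch that also drives the new Reshape path on CPU follows; the shapes, symbol names, and reshape spec here are illustrative assumptions, not part of this patch:

import mxnet as mx
import numpy as np

# Convolution leaves its output in an MKL-DNN internal layout, so the reshape
# that follows is exactly the case the new MKLDNNReshapeForward path handles.
data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data=data, name='conv', num_filter=8, kernel=(3, 3))
out = mx.sym.reshape(data=conv, shape=(0, -1))  # collapse (N, C, H, W) to (N, C*H*W)

exe = out.simple_bind(mx.cpu(), data=(2, 4, 10, 10))
exe.arg_dict['data'][:] = np.random.rand(2, 4, 10, 10).astype(np.float32)
exe.arg_dict['conv_weight'][:] = np.random.normal(size=exe.arg_dict['conv_weight'].shape)
exe.arg_dict['conv_bias'][:] = np.random.normal(size=exe.arg_dict['conv_bias'].shape)
y = exe.forward(is_train=False)[0]
y.wait_to_read()
assert y.shape == (2, 8 * 8 * 8)  # 8 filters, 8x8 spatial after a 3x3 valid conv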