From fe6d085a891b56ecf26af41ea7c01f1c0ecad82a Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Mon, 4 Mar 2019 22:13:16 +0800 Subject: [PATCH 1/8] add mkldnn transpose --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 + src/operator/nn/mkldnn/mkldnn_ops-inl.h | 6 + src/operator/nn/mkldnn/mkldnn_transpose.cc | 210 +++++++++++++++++++++ src/operator/tensor/matrix_op-inl.h | 17 ++ src/operator/tensor/matrix_op.cc | 33 ++++ 5 files changed, 268 insertions(+) create mode 100644 src/operator/nn/mkldnn/mkldnn_transpose.cc diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 0a89c0f31981..a460e33fa548 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -175,12 +175,14 @@ struct ConvolutionParam; struct DeconvolutionParam; struct SoftmaxParam; struct SoftmaxOutputParam; +struct TransposeParam; bool SupportMKLDNNAct(const ActivationParam& param); bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input); bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input); bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input); bool SupportMKLDNNSoftmax(const SoftmaxParam& param); bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam ¶m); +bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data); } // namespace op static int GetTypeSize(int dtype) { diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 39f26325b2a5..f3f61b457507 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -113,6 +113,12 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, const mkldnn::memory &out); +void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const NDArray &data, + const OpReqType &req, + const NDArray &output); + } // namespace op } // namespace mxnet #endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc new file mode 100644 index 000000000000..3dab15495f7a --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_transpose.cc + * \brief + * \author +*/ + +#include +#include +#include +#include +#include "../../operator_common.h" +#include "../../tensor/matrix_op-inl.h" +#include "./mkldnn_base-inl.h" + +#if MXNET_USE_MKLDNN == 1 + +#include + +namespace mxnet { +namespace op { + +// for 2D, 01-OI, 10-IO +// for 3D, 012-NCW, 021-NWC +// for 3D, 012-OIW, 210-WIO +// for 4D, 0123-NCHW, 0231-NHWC, 1230-CHWN +// for 4D, 0123-OIHW, 2310-HWIO, 1230-IHWO, 1023-IOHW +std::pair +GetFormatFromAxes(const mxnet::TShape &axes) { + auto src_fmt = mkldnn::memory::format::format_undef; + auto dst_fmt = mkldnn::memory::format::format_undef; + + if (axes.ndim() == 2) { + if (axes == mxnet::TShape({1, 0})) { + src_fmt = mkldnn::memory::format::oi; + dst_fmt = mkldnn::memory::format::io; + } + } else if (axes.ndim() == 3) { + if (axes == mxnet::TShape({0, 2, 1})) { + src_fmt = mkldnn::memory::format::ncw; + dst_fmt = mkldnn::memory::format::nwc; + } else if (axes == mxnet::TShape({2, 1, 0})) { + src_fmt = mkldnn::memory::format::oiw; + dst_fmt = mkldnn::memory::format::wio; + } else { + // do nothing + } + } else if (axes.ndim() == 4) { + if (axes == mxnet::TShape({0, 2, 3, 1})) { + src_fmt = mkldnn::memory::format::nchw; + dst_fmt = mkldnn::memory::format::nhwc; + } else if (axes == mxnet::TShape({1, 2, 3, 0})) { + src_fmt = mkldnn::memory::format::nchw; + dst_fmt = mkldnn::memory::format::chwn; + } else if (axes == mxnet::TShape({2, 3, 1, 0})) { + src_fmt = mkldnn::memory::format::oihw; + dst_fmt = mkldnn::memory::format::hwio; + // } else if (axes == mxnet::TShape({1, 0, 2, 3})) { + // src_fmt = mkldnn::memory::format::oihw; + // dst_fmt = mkldnn::memory::format::iohw; + } else { + // do nothing + } + } else { + // do nothing" + } + + return std::make_pair(src_fmt, dst_fmt); +} + + +bool SupportMKLDNNTranspose(const TransposeParam& param, + const NDArray &data) { + auto data_ndim = data.shape().ndim(); + auto axes_ndim = param.axes.ndim(); + + // currently, we dont support transposion for any internal format + if (data.IsMKLDNNData()) return false; + + auto axes = mxnet::TShape(data_ndim); + if (axes_ndim == 0) { + for (size_t i = 0; i < data_ndim; i++) { + axes[i] = data_ndim - i - 1; + } + } else { + axes = param.axes; + } + + CHECK_EQ(axes.ndim(), data_ndim); + + auto fmt_pair = GetFormatFromAxes(axes); + if (fmt_pair.first == mkldnn::memory::format::format_undef || + fmt_pair.second == mkldnn::memory::format::format_undef) { + return false; + } + + return true; +} + + +typedef ParamOpSign MKLDNNTransposeSignature; + +class MKLDNNTransposeForward { + std::shared_ptr data_; + std::shared_ptr out_; + std::shared_ptr dst_pd_; + std::shared_ptr transpose_; + + public: + MKLDNNTransposeForward(const TransposeParam& param, + const OpReqType &req, + const NDArray &data) { + auto data_ndim = data.shape().ndim(); + auto axes_ndim = param.axes.ndim(); + + auto axes = mxnet::TShape(data_ndim); + if (axes_ndim == 0) { + for (size_t i = 0; i < data_ndim; i++) { + axes[i] = data_ndim - i - 1; + } + } else { + axes = param.axes; + } + + auto fmt_pair = GetFormatFromAxes(axes); + + auto engine = CpuEngine::Get()->get_engine(); + auto dims = mkldnn::memory::dims(data.shape().begin(), data.shape().end()); + auto src_md = mkldnn::memory::desc(dims, get_mkldnn_type(data.dtype()), fmt_pair.first); + auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); + data_ = std::make_shared(src_pd, nullptr); + + auto dst_md = mkldnn::memory::desc(dims, get_mkldnn_type(data.dtype()), fmt_pair.second); + dst_pd_ = std::make_shared(dst_md, engine); + out_ = std::make_shared(*dst_pd_, nullptr); + + transpose_ = std::make_shared(*data_, *out_); + } + + void SetNewMem(const NDArray &data, const NDArray &output) { + MSHADOW_TYPE_SWITCH(data.dtype(), DTYPE, { + this->data_->set_data_handle(data.data().dptr()); + this->out_->set_data_handle(output.data().dptr()); + }); + } + + const mkldnn::reorder &GetFwd() const { + return *transpose_; + } +}; + +static MKLDNNTransposeForward &GetTransposeForward(const TransposeParam& param, + const OpReqType &req, + const NDArray &data) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map fwds; +#else + static MX_THREAD_LOCAL std::unordered_map fwds; +#endif + MKLDNNTransposeSignature key(param); + key.AddSign(data); + + auto it = fwds.find(key); + if (it == fwds.end()) { + MKLDNNTransposeForward fwd(param, req, data); + it = AddToCache(&fwds, key, fwd); + } + return it->second; +} + +void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const NDArray &data, + const OpReqType &req, + const NDArray &output) { + const TransposeParam& param = nnvm::get(attrs.parsed); + + CHECK_EQ(req, kWriteTo) << "Transpose does not support inplace"; + + auto *stream = MKLDNNStream::Get(); + auto fwd = GetTransposeForward(param, req, data); + + fwd.SetNewMem(data, output); + stream->RegisterPrim(fwd.GetFwd()); + stream->Submit(); +} +} // namespace op +} // namespace mxnet + +#endif diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 5eecda622729..6a40f69bac76 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -238,6 +238,10 @@ struct TransposeParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape()) .describe("Target axis order. By default the axes will be inverted."); } + + bool operator==(const TransposeParam &other) const { + return this->axes == other.axes; + } }; template @@ -2841,4 +2845,17 @@ inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) { } // namespace op } // namespace mxnet + +namespace std { +template<> +struct hash { + size_t operator()(const mxnet::op::TransposeParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.axes); + return ret; + } +}; +} // namespace std + + #endif // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_ diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 3bca330f98b0..678fe144f89c 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -339,6 +339,34 @@ Example:: }) .add_argument("data", "NDArray-or-Symbol", "Input array."); +#if MXNET_USE_MKLDNN == 1 +static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const TransposeParam& param = nnvm::get(attrs.parsed); + + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + if (SupportMKLDNNTranspose(param, inputs[0])) { + MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]); + return; + } + FallBackCompute(Transpose, attrs, ctx, inputs, req, outputs); +} + +inline static bool TransposeStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); +} +#endif + NNVM_REGISTER_OP(transpose) .describe(R"code(Permutes the dimensions of an array. @@ -393,6 +421,11 @@ Examples:: } }) .set_attr("FCompute", Transpose) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", TransposeComputeExCPU) +.set_attr("FInferStorageType", TransposeStorageType) +#endif .add_argument("data", "NDArray-or-Symbol", "Source input") .add_arguments(TransposeParam::__FIELDS__()); From 1f69da22b4624ff5e7fe14ac09ba0de5cf5efd74 Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Mon, 25 Mar 2019 13:31:04 +0800 Subject: [PATCH 2/8] general transpose --- src/operator/nn/mkldnn/mkldnn_transpose.cc | 92 +++++++--------------- 1 file changed, 27 insertions(+), 65 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc index 3dab15495f7a..7368e82dc6f3 100644 --- a/src/operator/nn/mkldnn/mkldnn_transpose.cc +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -38,55 +38,6 @@ namespace mxnet { namespace op { -// for 2D, 01-OI, 10-IO -// for 3D, 012-NCW, 021-NWC -// for 3D, 012-OIW, 210-WIO -// for 4D, 0123-NCHW, 0231-NHWC, 1230-CHWN -// for 4D, 0123-OIHW, 2310-HWIO, 1230-IHWO, 1023-IOHW -std::pair -GetFormatFromAxes(const mxnet::TShape &axes) { - auto src_fmt = mkldnn::memory::format::format_undef; - auto dst_fmt = mkldnn::memory::format::format_undef; - - if (axes.ndim() == 2) { - if (axes == mxnet::TShape({1, 0})) { - src_fmt = mkldnn::memory::format::oi; - dst_fmt = mkldnn::memory::format::io; - } - } else if (axes.ndim() == 3) { - if (axes == mxnet::TShape({0, 2, 1})) { - src_fmt = mkldnn::memory::format::ncw; - dst_fmt = mkldnn::memory::format::nwc; - } else if (axes == mxnet::TShape({2, 1, 0})) { - src_fmt = mkldnn::memory::format::oiw; - dst_fmt = mkldnn::memory::format::wio; - } else { - // do nothing - } - } else if (axes.ndim() == 4) { - if (axes == mxnet::TShape({0, 2, 3, 1})) { - src_fmt = mkldnn::memory::format::nchw; - dst_fmt = mkldnn::memory::format::nhwc; - } else if (axes == mxnet::TShape({1, 2, 3, 0})) { - src_fmt = mkldnn::memory::format::nchw; - dst_fmt = mkldnn::memory::format::chwn; - } else if (axes == mxnet::TShape({2, 3, 1, 0})) { - src_fmt = mkldnn::memory::format::oihw; - dst_fmt = mkldnn::memory::format::hwio; - // } else if (axes == mxnet::TShape({1, 0, 2, 3})) { - // src_fmt = mkldnn::memory::format::oihw; - // dst_fmt = mkldnn::memory::format::iohw; - } else { - // do nothing - } - } else { - // do nothing" - } - - return std::make_pair(src_fmt, dst_fmt); -} - - bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data) { auto data_ndim = data.shape().ndim(); @@ -105,13 +56,6 @@ bool SupportMKLDNNTranspose(const TransposeParam& param, } CHECK_EQ(axes.ndim(), data_ndim); - - auto fmt_pair = GetFormatFromAxes(axes); - if (fmt_pair.first == mkldnn::memory::format::format_undef || - fmt_pair.second == mkldnn::memory::format::format_undef) { - return false; - } - return true; } @@ -128,9 +72,9 @@ class MKLDNNTransposeForward { MKLDNNTransposeForward(const TransposeParam& param, const OpReqType &req, const NDArray &data) { - auto data_ndim = data.shape().ndim(); + auto shape = data.shape(); + auto data_ndim = shape.ndim(); auto axes_ndim = param.axes.ndim(); - auto axes = mxnet::TShape(data_ndim); if (axes_ndim == 0) { for (size_t i = 0; i < data_ndim; i++) { @@ -140,16 +84,34 @@ class MKLDNNTransposeForward { axes = param.axes; } - auto fmt_pair = GetFormatFromAxes(axes); - auto engine = CpuEngine::Get()->get_engine(); - auto dims = mkldnn::memory::dims(data.shape().begin(), data.shape().end()); - auto src_md = mkldnn::memory::desc(dims, get_mkldnn_type(data.dtype()), fmt_pair.first); - auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); + auto in_mem = data.GetMKLDNNData(); + auto src_pd = in_mem->get_primitive_desc(); data_ = std::make_shared(src_pd, nullptr); - auto dst_md = mkldnn::memory::desc(dims, get_mkldnn_type(data.dtype()), fmt_pair.second); - dst_pd_ = std::make_shared(dst_md, engine); + // destination + mkldnn_memory_desc_t dst_fmt; + dst_fmt.primitive_kind = mkldnn_memory; + dst_fmt.ndims = data_ndim; + dst_fmt.data_type = mkldnn_f32; + dst_fmt.format = mkldnn_blocked; + + for (size_t i = 0; i < data_ndim; i++) + dst_fmt.dims[i] = shape[i]; + + unsigned int total_stride = 1; + for (int i = data_ndim - 1; i >= 0; i--) { + dst_fmt.layout_desc.blocking.padding_dims[i] = shape[i]; + dst_fmt.layout_desc.blocking.block_dims[i] = 1; + dst_fmt.layout_desc.blocking.offset_padding_to_data[i]= 0; + dst_fmt.layout_desc.blocking.strides[0][axes[i]] = total_stride; + dst_fmt.layout_desc.blocking.strides[1][axes[i]] = 1; + + total_stride *= shape[axes[i]]; + } + + dst_fmt.layout_desc.blocking.offset_padding = 0; + dst_pd_ = std::make_shared(dst_fmt, engine); out_ = std::make_shared(*dst_pd_, nullptr); transpose_ = std::make_shared(*data_, *out_); From 619ae9348a54d445deddd86701b13a511ae81682 Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Wed, 27 Mar 2019 22:01:20 +0800 Subject: [PATCH 3/8] support mkldnn format --- src/operator/nn/mkldnn/mkldnn_transpose.cc | 41 +++++++--------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc index 7368e82dc6f3..f9fb3348b39d 100644 --- a/src/operator/nn/mkldnn/mkldnn_transpose.cc +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -23,17 +23,10 @@ * \author */ -#include -#include -#include -#include -#include "../../operator_common.h" -#include "../../tensor/matrix_op-inl.h" -#include "./mkldnn_base-inl.h" - #if MXNET_USE_MKLDNN == 1 #include +#include "../../tensor/matrix_op-inl.h" namespace mxnet { namespace op { @@ -41,25 +34,13 @@ namespace op { bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data) { auto data_ndim = data.shape().ndim(); - auto axes_ndim = param.axes.ndim(); - // currently, we dont support transposion for any internal format - if (data.IsMKLDNNData()) return false; - - auto axes = mxnet::TShape(data_ndim); - if (axes_ndim == 0) { - for (size_t i = 0; i < data_ndim; i++) { - axes[i] = data_ndim - i - 1; - } - } else { - axes = param.axes; - } + if (data_ndim > 4 || data.dtype() != mshadow::kFloat32) + return false; - CHECK_EQ(axes.ndim(), data_ndim); return true; } - typedef ParamOpSign MKLDNNTransposeSignature; class MKLDNNTransposeForward { @@ -118,10 +99,14 @@ class MKLDNNTransposeForward { } void SetNewMem(const NDArray &data, const NDArray &output) { - MSHADOW_TYPE_SWITCH(data.dtype(), DTYPE, { - this->data_->set_data_handle(data.data().dptr()); - this->out_->set_data_handle(output.data().dptr()); - }); + if (data.IsMKLDNNData()) { + this->data_->set_data_handle(data.GetMKLDNNData()->get_data_handle()); + } else { + this->data_->set_data_handle(data.data().dptr()); + } + + CHECK(! output.IsMKLDNNData()); + this->out_->set_data_handle(output.data().dptr()); } const mkldnn::reorder &GetFwd() const { @@ -156,10 +141,9 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs, const OpReqType &req, const NDArray &output) { const TransposeParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(req, kWriteTo) << "Transpose does not support inplace"; - auto *stream = MKLDNNStream::Get(); + auto stream = MKLDNNStream::Get(); auto fwd = GetTransposeForward(param, req, data); fwd.SetNewMem(data, output); @@ -168,5 +152,4 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs, } } // namespace op } // namespace mxnet - #endif From 5193eae68cfe66d44c92510c11c92eda2f3701e2 Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Wed, 27 Mar 2019 22:35:46 +0800 Subject: [PATCH 4/8] fix lint --- src/operator/nn/mkldnn/mkldnn_transpose.cc | 4 ++-- src/operator/tensor/matrix_op-inl.h | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc index f9fb3348b39d..6c2fd63c3573 100644 --- a/src/operator/nn/mkldnn/mkldnn_transpose.cc +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -100,12 +100,12 @@ class MKLDNNTransposeForward { void SetNewMem(const NDArray &data, const NDArray &output) { if (data.IsMKLDNNData()) { - this->data_->set_data_handle(data.GetMKLDNNData()->get_data_handle()); + this->data_->set_data_handle(data.GetMKLDNNData()->get_data_handle()); } else { this->data_->set_data_handle(data.data().dptr()); } - CHECK(! output.IsMKLDNNData()); + CHECK(!output.IsMKLDNNData()); this->out_->set_data_handle(output.data().dptr()); } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 6a40f69bac76..fa108158b5c9 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -240,7 +240,7 @@ struct TransposeParam : public dmlc::Parameter { } bool operator==(const TransposeParam &other) const { - return this->axes == other.axes; + return this->axes == other.axes; } }; @@ -2845,7 +2845,6 @@ inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) { } // namespace op } // namespace mxnet - namespace std { template<> struct hash { @@ -2857,5 +2856,4 @@ struct hash { }; } // namespace std - #endif // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_ From bcbb971e6803dea933f883368b308a2a143733c9 Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Thu, 28 Mar 2019 14:40:38 +0800 Subject: [PATCH 5/8] address comments --- src/operator/nn/mkldnn/mkldnn_transpose.cc | 15 ++++++++------- src/operator/tensor/matrix_op.cc | 3 ++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc index 6c2fd63c3573..9c9a2a5d24f6 100644 --- a/src/operator/nn/mkldnn/mkldnn_transpose.cc +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -51,7 +51,6 @@ class MKLDNNTransposeForward { public: MKLDNNTransposeForward(const TransposeParam& param, - const OpReqType &req, const NDArray &data) { auto shape = data.shape(); auto data_ndim = shape.ndim(); @@ -102,11 +101,15 @@ class MKLDNNTransposeForward { if (data.IsMKLDNNData()) { this->data_->set_data_handle(data.GetMKLDNNData()->get_data_handle()); } else { - this->data_->set_data_handle(data.data().dptr()); + MSHADOW_TYPE_SWITCH(data.dtype(), DTYPE, { + this->data_->set_data_handle(data.data().dptr()); + }); } CHECK(!output.IsMKLDNNData()); - this->out_->set_data_handle(output.data().dptr()); + MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, { + this->out_->set_data_handle(output.data().dptr()); + }); } const mkldnn::reorder &GetFwd() const { @@ -115,7 +118,6 @@ class MKLDNNTransposeForward { }; static MKLDNNTransposeForward &GetTransposeForward(const TransposeParam& param, - const OpReqType &req, const NDArray &data) { #if DMLC_CXX11_THREAD_LOCAL static thread_local std::unordered_mapsecond; @@ -141,10 +143,9 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs, const OpReqType &req, const NDArray &output) { const TransposeParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(req, kWriteTo) << "Transpose does not support inplace"; auto stream = MKLDNNStream::Get(); - auto fwd = GetTransposeForward(param, req, data); + auto fwd = GetTransposeForward(param, data); fwd.SetNewMem(data, output); stream->RegisterPrim(fwd.GetFwd()); diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 678fe144f89c..1431fef13594 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -346,9 +346,10 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const TransposeParam& param = nnvm::get(attrs.parsed); - + CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace"; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); + if (SupportMKLDNNTranspose(param, inputs[0])) { MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]); return; From 3a06c9d2b346b6120da9e078fcb395a8442a8167 Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Sun, 31 Mar 2019 22:41:35 +0800 Subject: [PATCH 6/8] add unit test --- tests/python/mkl/test_mkldnn.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 01ba03cab7cd..0610b606c201 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -473,6 +473,21 @@ def backward(self, req, out_grad, in_data, out_data, in_grad, aux): exec1 = custom.bind(mx.cpu(), args={'data': mx.nd.ones([10,3,96,96]), 'conv_weight': mx.nd.ones([8,3,5,5])}) exec1.forward()[0].wait_to_read() +@with_seed() +def test_conv_transpose(): + axes = [(0,2,1,3), (0,2,3,1), (1,2,3,0), (3,2,1,0)] + a = np.random.rand(10, 16, 50, 50) + b = np.random.rand(32, 16, 3, 3) + x = mx.nd.array(a) + w = mx.nd.array(b) + y = mx.nd.Convolution(data=x, weight=w, kernel=(3, 3), num_group=1, num_filter=32, no_bias=True) + for axis in axes: + t = mx.nd.transpose(y, axis) + t.wait_to_read() + s = y.asnumpy() + n = np.transpose(s, axis) + np.allclose(t.asnumpy(), n) + if __name__ == '__main__': install.test_mkldnn_install() From 40334b97388bcfd3860d327f2a01ba6f3e49d7b9 Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Sun, 31 Mar 2019 23:32:21 +0800 Subject: [PATCH 7/8] add comments --- src/operator/nn/mkldnn/mkldnn_transpose.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_transpose.cc b/src/operator/nn/mkldnn/mkldnn_transpose.cc index 9c9a2a5d24f6..0986d0616f75 100644 --- a/src/operator/nn/mkldnn/mkldnn_transpose.cc +++ b/src/operator/nn/mkldnn/mkldnn_transpose.cc @@ -19,8 +19,8 @@ /*! * \file mkldnn_transpose.cc - * \brief - * \author + * \brief Implement transpose operator via MKL-DNN reorder primitive + * \author Tao Lv */ #if MXNET_USE_MKLDNN == 1 @@ -70,6 +70,9 @@ class MKLDNNTransposeForward { data_ = std::make_shared(src_pd, nullptr); // destination + // Not all formats are well defined with a certain name in MKL-DNN. + // For example, transpose(NCHW, (0, 2, 1, 3)) -> NHCW, which is not explicitly defined in + // MKL-DNN. To support general transposing, we need create destination format from scratch. mkldnn_memory_desc_t dst_fmt; dst_fmt.primitive_kind = mkldnn_memory; dst_fmt.ndims = data_ndim; @@ -84,7 +87,9 @@ class MKLDNNTransposeForward { dst_fmt.layout_desc.blocking.padding_dims[i] = shape[i]; dst_fmt.layout_desc.blocking.block_dims[i] = 1; dst_fmt.layout_desc.blocking.offset_padding_to_data[i]= 0; + // strides[0]: stride between the first elements of adjacent blocks. dst_fmt.layout_desc.blocking.strides[0][axes[i]] = total_stride; + // strides[1]: strides between elements in the same block. dst_fmt.layout_desc.blocking.strides[1][axes[i]] = 1; total_stride *= shape[axes[i]]; From ef218de553c8f841ec310c5c8c317c6390de4151 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Sun, 7 Apr 2019 07:49:13 +0800 Subject: [PATCH 8/8] retrigger CI