From 1ff942948599dd248446fcb610b9fe0cc3070580 Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Wed, 18 Sep 2019 11:09:03 +0800 Subject: [PATCH] [mkldnn-v1.0] Add MKL-DNN Convolution (#16141) * add mkldnn conv * revert unnecessary change * fix testcase fail for cpu: test_convolution_independent_gradients * fix failed testcase: test_reshape_transpose_6d&&test_weight_async_reorder * fix comments * change variable name from weights to weight in mkldnn_conv --- include/mxnet/ndarray.h | 4 +- src/common/exec_utils.h | 8 +- src/executor/attach_op_execs_pass.cc | 8 +- src/imperative/imperative_utils.h | 20 +- src/ndarray/ndarray.cc | 16 +- src/operator/nn/convolution.cc | 28 +- src/operator/nn/mkldnn/mkldnn_base-inl.h | 22 + src/operator/nn/mkldnn/mkldnn_base.cc | 3 +- .../nn/mkldnn/mkldnn_convolution-inl.h | 67 ++- src/operator/nn/mkldnn/mkldnn_convolution.cc | 539 ++++++------------ src/operator/nn/mkldnn/mkldnn_ops-inl.h | 21 +- src/operator/operator_common.h | 41 +- src/operator/tensor/cast_storage-inl.h | 6 +- 13 files changed, 349 insertions(+), 434 deletions(-) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index fc4375b493a7..16bb7e40b7a2 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -761,8 +761,8 @@ class NDArray { * It changes the layout of this NDArray, but it happens after all accesses to * the array are complete. */ - void Reorder2DefaultAsync(); - void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md); + void Reorder2DefaultAsync() const; + void MKLDNNDataReorderAsync(const mkldnn::memory::desc &md) const; /* * This creates a new NDArray with the reordered data. diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index d8b7a33bf22b..f0b29e75147e 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -59,7 +59,7 @@ inline bool SetupDefaultBlobsIn(const std::vector& src, for (size_t i = 0; i < src.size(); i++) { auto& nd = src[i]; bool is_default = nd.storage_type() == kDefaultStorage; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 // We have to make sure it's default storage and default layout. is_default = nd.IsDefaultData(); #endif @@ -67,7 +67,7 @@ inline bool SetupDefaultBlobsIn(const std::vector& src, (*idx_map)[i] = temp_dst->size(); NDArray temp = bufs != nullptr ? 
bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 CHECK(temp.IsDefaultData()); #endif temp_src->emplace_back(nd); @@ -91,7 +91,7 @@ inline bool SetupDefaultBlobsOut(const std::vector& src, for (size_t i = 0; i < src.size(); i++) { auto& nd = src[i]; bool is_default = nd.storage_type() == kDefaultStorage; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (req->at(i) == kWriteInplace && nd.IsMKLDNNData()) // If it's write inplace and the output array doesn't use the default // layout, we'll generate a temporary output array below, which means // the input array and the output array are different. // In this case, we should change the request type. req->at(i) = kWriteTo; // We have to make sure it's default storage and default layout. is_default = nd.IsDefaultData(); #endif if (!is_default) { -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 NDArray temp; if (bufs != nullptr) { temp = bufs->at(i); diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 8f47bc29db13..ebd032830e3c 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -116,7 +116,7 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 InvalidateOutputs(out_array, req); #endif PreFCompute(is_gpu); @@ -155,7 +155,7 @@ class StatefulComputeExExecutor : public OpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 InvalidateOutputs(out_array, req); // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); @@ -202,7 +202,7 @@ class FComputeExecutor : public StorageFallbackOpExecutor { void Run(RunContext rctx, bool is_gpu) override { using namespace common; op_ctx.run_ctx = rctx; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 InvalidateOutputs(out_array, req); #endif PreFCompute(is_gpu); @@ -231,7 +231,7 @@ class FComputeExExecutor : public OpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 InvalidateOutputs(out_array, req); // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 21caafa124f9..3a2875ea7eb3 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -418,7 +418,7 @@ inline void PushFCompute(const FCompute& fn, std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; // mapping from index in input_blobs to index in pre_temp_dst std::unordered_map in_temp_idx_map; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (exec_type != ExecType::kCrossDeviceCopy) { // kCrossDeviceCopy is used for `_copy_to` operator, which doesn't compute immediately in // its FComputeEx, but AsyncPushes the copy operation to the engine. 
@@ -467,7 +467,7 @@ inline void PushFComputeEx(const FComputeEx& fn, DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); const auto& run = [=](RunContext rctx) { OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (exec_type != ExecType::kCrossDeviceCopy) { // kCrossDeviceCopy is used for `_copy_to` operator, which doesn't compute immediately in // its FComputeEx, but AsyncPushes the copy operation to the engine. @@ -476,8 +476,18 @@ inline void PushFComputeEx(const FComputeEx& fn, // copying A to B may not happen, and will corrupt A's memory. InvalidateOutputs(outputs, req); } + // Handle an MKLDNN op followed by a non-MKLDNN op: fall back to default-layout inputs. + const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); + if (!is_mkldnn.get(attrs.op, false)) { + std::vector inputs_fallback; + CreateDefaultInputs(inputs, &inputs_fallback); + fn(attrs, opctx, inputs_fallback, req, outputs); + } else { +#endif + fn(attrs, opctx, inputs, req, outputs); +#if MXNET_USE_MKLDNN == 100 + } #endif - fn(attrs, opctx, inputs, req, outputs); if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && !rctx.is_bulk) { rctx.get_stream()->Wait(); } @@ -521,7 +531,7 @@ inline void PushOperator(const OpStatePtr& state, const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (exec_type != ExecType::kCrossDeviceCopy) { // kCrossDeviceCopy is used for `_copy_to` operator, which doesn't compute immediately in // its FComputeEx, but AsyncPushes the copy operation to the engine. @@ -567,7 +577,7 @@ inline void PushOperator(const OpStatePtr& state, std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; // mapping from index in input_blobs to index in pre_temp_dst std::unordered_map in_temp_idx_map; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (exec_type != ExecType::kCrossDeviceCopy) { // kCrossDeviceCopy is used for `_copy_to` operator, which doesn't compute immediately in // its FComputeEx, but AsyncPushes the copy operation to the engine. diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 97daa29492f6..f174d75553b0 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -474,7 +474,7 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape &shape, int dtype) { mkldnn::memory::dims dims; // These are shapes supported by MKLDNN. 
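
The hunk below raises the supported rank of this plain-layout path from 5-D to 6-D (exercised by test_reshape_transpose_6d in the commit log). As a minimal standalone sketch of what the new abcdef case enables, assuming the MKL-DNN v1.0 headers and arbitrary example dimensions:

    // Editor's sketch, not part of the patch: create a plain 6-D descriptor.
    #include <mkldnn.hpp>

    int main() {
      mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
      mkldnn::memory::dims dims = {2, 3, 4, 5, 6, 7};  // an arbitrary 6-D shape
      mkldnn::memory::desc md(dims, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::abcdef);
      mkldnn::memory mem(md, eng);  // plain row-major buffer, no blocking
      // For a plain format the size is exactly the product of the dims:
      return md.get_size() == 2 * 3 * 4 * 5 * 6 * 7 * sizeof(float) ? 0 : 1;
    }
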
- if (shape.ndim() >= 1 && shape.ndim() <= 5) { + if (shape.ndim() >= 1 && shape.ndim() <= 6) { dims.resize(shape.ndim()); for (size_t i = 0; i < dims.size(); i++) dims[i] = shape[i]; @@ -488,6 +488,7 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape &shape, int dtype) { case 3: layout = mkldnn::memory::format_tag::abc; break; case 4: layout = mkldnn::memory::format_tag::abcd; break; case 5: layout = mkldnn::memory::format_tag::abcde; break; + case 6: layout = mkldnn::memory::format_tag::abcdef; break; default: LOG(FATAL) << "Not implemented dimension (" << dims.size() << ") for MKLDNN"; } @@ -592,7 +593,7 @@ NDArray NDArray::Reorder2Default() const { return ret; } -void NDArray::Reorder2DefaultAsync() { +void NDArray::Reorder2DefaultAsync() const { std::vector const_vars; std::vector mutable_vars(1, this->var()); NDArray tmp = *this; @@ -604,13 +605,18 @@ void NDArray::Reorder2DefaultAsync() { FnProperty::kNormal, 0, "Reorder2Default"); } -void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc &desc) { +void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::desc &desc) const { std::vector const_vars; std::vector mutable_vars(1, this->var()); NDArray tmp = *this; + const auto version = this->version(); Engine::Get()->PushAsync( - [tmp, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) { - tmp.ptr_->MKLDNNDataReorder(desc); + [tmp, version, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) { + // MXNet will try to reuse NDArray from memory planning, so we need to ensure + // the NDArray is still holding the original trunk data. + if (tmp.version() == version) { + tmp.ptr_->MKLDNNDataReorder(desc); + } on_complete(); }, ctx(), const_vars, mutable_vars, FnProperty::kNormal, 0, "Reorder"); diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 32ed93e4a463..ad191283d87a 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -30,7 +30,7 @@ #if MXNET_USE_NNPACK == 1 #include "../nnpack/nnpack_pooling-inl.h" #endif // MXNET_USE_NNPACK -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "./mkldnn/mkldnn_base-inl.h" #include "./mkldnn/mkldnn_ops-inl.h" #endif // MXNET_USE_MKLDNN @@ -51,7 +51,7 @@ static inline std::vector ListArguments(const ConvolutionParam& par } } -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -60,7 +60,12 @@ static void ConvolutionComputeExCPU(const nnvm::NodeAttrs& attrs, const ConvolutionParam& params = nnvm::get(attrs.parsed); if (SupportMKLDNNConv(params, inputs[0])) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); - MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs); + if (CheckMKLDNNInputArrayIsView(inputs)) { + const auto mkldnn_inputs = GetMKLDNNInputArray(inputs); + MKLDNNConvolutionForward(attrs, ctx, mkldnn_inputs, req, outputs); + } else { + MKLDNNConvolutionForward(attrs, ctx, inputs, req, outputs); + } MKLDNN_OPCHECK_RUN(ConvolutionCompute, attrs, ctx, inputs, req, outputs); return; } @@ -75,7 +80,12 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs, const ConvolutionParam& params = nnvm::get(attrs.parsed); if (SupportMKLDNNConv(params, inputs[0])) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs); + if (CheckMKLDNNInputArrayIsView(inputs)) { + const auto mkldnn_inputs = GetMKLDNNInputArray(inputs); + 
MKLDNNConvolutionBackward(attrs, ctx, mkldnn_inputs, req, outputs); + } else { + MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs); + } MKLDNN_OPCHECK_RUN(ConvolutionGradCompute, attrs, ctx, inputs, req, outputs); return; } @@ -302,7 +312,7 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs, return true; } -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 inline static bool ConvStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, @@ -491,11 +501,11 @@ There are other options to tune the performance. }) .set_attr("FInferShape", ConvolutionShape) .set_attr("FInferType", ConvolutionType) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("FInferStorageType", ConvStorageType) #endif .set_attr("FCompute", ConvolutionCompute) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", ConvolutionComputeExCPU) #endif @@ -514,14 +524,14 @@ NNVM_REGISTER_OP(_backward_Convolution) return params.no_bias ? 2 : 3; }) .set_attr("TIsBackward", true) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("FInferStorageType", BackwardConvStorageType) #endif .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) .set_attr_parser(ConvolutionParamParser) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", ConvolutionGradComputeExCPU) #endif diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 85d42ff48d35..054f422deb8f 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -277,6 +277,28 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, } } +inline static bool CheckMKLDNNInputArrayIsView(const std::vector &inputs) { + for (const auto &in : inputs) { + if (in.IsView() && in.IsMKLDNNData()) { + return true; + } + } + return false; +} + +inline static const std::vector GetMKLDNNInputArray(const std::vector &inputs) { + std::vector ret; + ret.reserve(inputs.size()); + for (const auto &in : inputs) { + if (in.IsView() && in.IsMKLDNNData()) { + ret.push_back(in.Reorder2Default()); + } else { + ret.push_back(in); + } + } + return ret; +} + typedef std::shared_ptr mkldnn_mem_ptr; typedef std::shared_ptr mkldnn_mem_const_ptr; diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 31ffbbb471c4..cfd7ad71a31d 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -312,6 +312,7 @@ mkldnn_format_tag_t GetDefaultFormat(int num_dims) { case 3: return mkldnn_abc; case 4: return mkldnn_abcd; case 5: return mkldnn_abcde; + case 6: return mkldnn_abcdef; default: LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for MKLDNN"; return mkldnn_format_tag_undef; @@ -530,7 +531,7 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs, if (v == - 1) v = kDefaultStorage; DispatchMode wanted_mode; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (dev_mask == mshadow::cpu::kDevMask && !MKLDNNEnvSet()) wanted_mode = DispatchMode::kFComputeFallback; else if (dev_mask == mshadow::cpu::kDevMask && support_mkldnn) diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index 880b9d19cd81..b3ceb9595730 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -25,7 
+25,7 @@ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include #include @@ -79,47 +79,26 @@ struct MKLDNNConvFullParam { MKLDNNPostEltwiseParam postsum_act_param; }; -mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam ¶m, - const bool is_train, - const NDArray &data, - const NDArray &weights, - const NDArray *bias, - const NDArray &output); +std::shared_ptr GetConvFwdImpl( + const ConvolutionParam ¶m, const bool is_train, const NDArray &data, const NDArray &weight, + const NDArray *bias, const NDArray &output); class MKLDNNConvForward { public: - mkldnn::convolution_forward::primitive_desc fwd_pd; - MKLDNNConvForward(const MKLDNNConvFullParam ¶m, const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output); + const NDArray &weight, const NDArray *bias, const NDArray &output); - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output); + const mkldnn::convolution_forward &GetFwd() const { return *fwd_; } - void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) { - this->data_->set_data_handle(data.get_data_handle()); - this->out_->set_data_handle(output.get_data_handle()); - } - - const mkldnn::convolution_forward &GetFwd() const { - return *fwd_; - } + const mkldnn::convolution_forward::primitive_desc &GetPd() const { return *pd_; } private: std::shared_ptr fwd_; - std::shared_ptr data_; - std::shared_ptr weight_; - std::shared_ptr bias_; - std::shared_ptr out_; + std::shared_ptr pd_; }; typedef ParamOpSign MKLDNNConvSignature; -MKLDNNConvForward &GetConvFwd(const ConvolutionParam ¶m, - const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, - const NDArray &output); - void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, const OpContext &ctx, MKLDNNConvForward *fwd, @@ -127,6 +106,36 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, const std::vector &req, const std::vector &out_data); +void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); + +class MKLDNNConvBackward { + public: + MKLDNNConvBackward(const MKLDNNConvFullParam ¶m, const NDArray &data, const NDArray &weight, + const NDArray *bias, const NDArray &output); + + const mkldnn::convolution_backward_data &GetBwdData() const { return *bwd_data_; } + + const mkldnn::convolution_backward_weights &GetBwdWeights() const { return *bwd_weight_; } + + const mkldnn::convolution_backward_data::primitive_desc &GetDataPd() const { + return *bwd_data_pd_; + } + + const mkldnn::convolution_backward_weights::primitive_desc &GetWeightsPd() const { + return *bwd_weight_pd_; + } + + private: + std::shared_ptr bwd_data_pd_; + std::shared_ptr bwd_weight_pd_; + std::shared_ptr bwd_data_; + std::shared_ptr bwd_weight_; +}; + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 9cab2dd0e2b3..41141884a08b 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -21,10 +21,9 @@ * \file mkldnn_convolution.cc * \brief * \author Da Zheng -*/ - + */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include 
"../convolution-inl.h" #include "./mkldnn_ops-inl.h" @@ -45,8 +44,10 @@ bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) { (input.shape().ndim() == 4)); } -mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam ¶m, - const bool is_train, const NDArray &data, +std::shared_ptr GetConvFwdImpl( + const MKLDNNConvFullParam ¶m, + const bool is_train, + const NDArray &data, const NDArray &weights, const NDArray *bias, const NDArray &output) { @@ -57,7 +58,7 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP auto bias_md = bias ? (param.mkldnn_param.quantized ? GetMemDesc(*bias, mshadow::kInt32) : GetMemDesc(*bias)) : mkldnn::memory::desc{ - {}, mkldnn::memory::data_type::data_undef, mkldnn::memory::format::any}; + {}, mkldnn::memory::data_type::undef, mkldnn::memory::format_tag::any}; auto bias_md_ptr = bias ? &bias_md : nullptr; mkldnn::memory::dims strides(param.conv_param.kernel.ndim()); @@ -98,19 +99,19 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP if (param.mkldnn_param.quantized && param.requantize_scales.size()) { int mask = (param.requantize_scales.size() > 1) ? 2 : 0; attr.set_output_scales(mask, param.requantize_scales); - attr.set_int_output_round_mode(round_nearest); } auto GetConvFwdPd = [¶m, &data, &weights, &output, &attr](const mkldnn::convolution_forward::desc &desc) { auto engine = CpuEngine::Get()->get_engine(); try { - auto conv_pd = mkldnn::convolution_forward::primitive_desc(desc, attr, engine); - while (conv_pd.dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || + auto conv_pd = + std::make_shared(desc, attr, engine); + while (conv_pd->dst_desc().get_size() != GetArraySize(output) || + conv_pd->src_desc().get_size() != GetArraySize(data) || (!param.mkldnn_param.quantized && - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights))) { + conv_pd->weights_desc().get_size() != GetArraySize(weights))) { // next_impl() will visit desc and engine, please make sure they are still alive here. 
- CHECK(conv_pd.next_impl()) << "No convolution implementation for this request."; + CHECK(conv_pd->next_impl()) << "No convolution implementation for this request."; } return conv_pd; } catch (mkldnn::error &e) { @@ -126,13 +127,12 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP if (param.conv_param.dilate.ndim() == 0 && bias_md_ptr == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, - weight_md, out_md, strides, padding, padding, - mkldnn::padding_kind::zero); + weight_md, out_md, strides, padding, padding); return GetConvFwdPd(desc); } else if (param.conv_param.dilate.ndim() == 0) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, *bias_md_ptr, out_md, strides, padding, - padding, mkldnn::padding_kind::zero); + padding); return GetConvFwdPd(desc); } else { mkldnn::memory::dims dilates(param.conv_param.kernel.ndim()); @@ -147,23 +147,22 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP } if (bias_md_ptr == nullptr) { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, - weight_md, out_md, strides, dilates, padding, padding, - mkldnn::padding_kind::zero); + weight_md, out_md, strides, dilates, padding, padding); return GetConvFwdPd(desc); } else { mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md, weight_md, *bias_md_ptr, out_md, strides, dilates, - padding, padding, mkldnn::padding_kind::zero); + padding, padding); return GetConvFwdPd(desc); } } } -static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData( - const ConvolutionParam& param, const NDArray &data, const NDArray &weights, +static std::shared_ptr GetConvBwdData( + const ConvolutionParam ¶m, const NDArray &data, const NDArray &weight, const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); + auto weight_md = GetWeightDesc(weight, param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); mkldnn::memory::dims strides(param.kernel.ndim()); @@ -187,21 +186,29 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData( << ", supporting only 1 or 2."; } - // MKL-DNN introduced padded formats since 0.15 which require more memory - // for computation compared with the actual tensor size. Currently, MKL-DNN - // operators are still reusing those memory from memory planning and the - // memory size may smaller than what MKL-DNN kernels require. So here we need - // select suboptimal kernel for computation according to tensor sizes. 
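
The comment block deleted above still explains the motivation for the size checks: padded (blocked) formats can require more memory than the logical tensor. A standalone editor's sketch of the effect, assuming the MKL-DNN v1.0 headers (the shape is an arbitrary example):

    #include <cstdio>
    #include <mkldnn.hpp>

    int main() {
      using namespace mkldnn;
      // Same logical shape, two layouts: nChw16c pads 3 channels up to 16.
      memory::desc plain({1, 3, 8, 8}, memory::data_type::f32, memory::format_tag::nchw);
      memory::desc blocked({1, 3, 8, 8}, memory::data_type::f32, memory::format_tag::nChw16c);
      std::printf("plain: %zu bytes, blocked: %zu bytes\n",
                  plain.get_size(), blocked.get_size());
      return 0;
    }

Here nChw16c pads the 3 input channels up to a full block of 16, so get_size() reports 4096 bytes against the 768 bytes of the plain nchw layout.
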
- if (param.dilate.ndim() == 0) { - mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.diff_src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; + auto GetConvBwdDataPd = [&data, &weight, &output, + &fwd_pd](const mkldnn::convolution_backward_data::desc &desc) { + auto engine = CpuEngine::Get()->get_engine(); + try { + auto conv_pd = + std::make_shared(desc, engine, fwd_pd); + while (conv_pd->diff_dst_desc().get_size() != GetArraySize(output) || + conv_pd->diff_src_desc().get_size() != GetArraySize(data) || + conv_pd->weights_desc().get_size() != GetArraySize(weight)) { + // next_impl() will visit desc and engine, please make sure they are still alive here. + CHECK(conv_pd->next_impl()) << "No convolution backward implementation for this request."; + } + return conv_pd; + } catch (mkldnn::error &e) { + LOG(ERROR) << e.message; + throw; } - return conv_pd; + }; + + if (param.dilate.ndim() == 0) { + mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md, + weight_md, out_md, strides, padding, padding); + return GetConvBwdDataPd(desc); } else { mkldnn::memory::dims dilates(param.kernel.ndim()); if (param.dilate.ndim() == 1) { @@ -213,25 +220,18 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData( LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size " << param.dilate.ndim() << ", supporting only 1 or 2."; } - mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, dilates, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_data::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.diff_src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md, + weight_md, out_md, strides, dilates, padding, + padding); + return GetConvBwdDataPd(desc); } } -static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( - const ConvolutionParam& param, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &fwd_pd) { +static std::shared_ptr GetConvBwdWeights( + const ConvolutionParam ¶m, const NDArray &data, const NDArray &weight, const NDArray *bias, + const NDArray &output, const mkldnn::convolution_forward::primitive_desc &fwd_pd) { auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); + auto weight_md = GetWeightDesc(weight, param.num_group); auto out_md = GetMemDesc(output); auto engine = CpuEngine::Get()->get_engine(); mkldnn::memory::dims strides(param.kernel.ndim()); @@ -255,33 +255,35 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( << ", supporting only 1 or 2."; } - // MKL-DNN introduced padded formats since 0.15 which require 
more memory - // for computation compared with the actual tensor size. Currently, MKL-DNN - // operators are still reusing those memory from memory planning and the - // memory size may smaller than what MKL-DNN kernels require. So here we need - // select suboptimal kernel for computation according to tensor sizes. - if (param.dilate.ndim() == 0 && bias == nullptr) { - mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, padding, padding, mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; + auto GetConvBwdWeightsPd = [&data, &weight, &output, + &fwd_pd](const mkldnn::convolution_backward_weights::desc &desc) { + auto engine = CpuEngine::Get()->get_engine(); + try { + auto conv_pd = std::make_shared( + desc, engine, fwd_pd); + while (conv_pd->diff_dst_desc().get_size() != GetArraySize(output) || + conv_pd->src_desc().get_size() != GetArraySize(data) || + conv_pd->diff_weights_desc().get_size() != GetArraySize(weight)) { + // next_impl() will visit desc and engine, please make sure they are still alive here. + CHECK(conv_pd->next_impl()) << "No convolution backward implementation for this request."; + } + return conv_pd; + } catch (mkldnn::error &e) { + LOG(ERROR) << e.message; + throw; } - return conv_pd; + }; + + if (param.dilate.ndim() == 0 && bias == nullptr) { + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, data_md, + weight_md, out_md, strides, padding, padding); + return GetConvBwdWeightsPd(desc); } else if (param.dilate.ndim() == 0) { auto bias_md = GetMemDesc(*bias); - mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, bias_md, out_md, strides, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, data_md, + weight_md, bias_md, out_md, strides, padding, + padding); + return GetConvBwdWeightsPd(desc); } else { mkldnn::memory::dims dilates(param.kernel.ndim()); if (param.dilate.ndim() == 1) { @@ -295,313 +297,154 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights( } if (bias == nullptr) { mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, out_md, strides, dilates, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + data_md, 
weight_md, out_md, strides, dilates, + padding, padding); + return GetConvBwdWeightsPd(desc); } else { auto bias_md = GetMemDesc(*bias); mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - data_md, weight_md, bias_md, out_md, - strides, dilates, padding, padding, - mkldnn::padding_kind::zero); - auto conv_pd = mkldnn::convolution_backward_weights::primitive_desc(desc, engine, fwd_pd); - while (conv_pd.diff_dst_primitive_desc().get_size() != GetArraySize(output) || - conv_pd.src_primitive_desc().get_size() != GetArraySize(data) || - conv_pd.diff_weights_primitive_desc().get_size() != GetArraySize(weights)) { - CHECK(conv_pd.next_impl()) << "No implementation"; - } - return conv_pd; + data_md, weight_md, bias_md, out_md, strides, + dilates, padding, padding); + return GetConvBwdWeightsPd(desc); } } } MKLDNNConvForward::MKLDNNConvForward(const MKLDNNConvFullParam &param, const bool is_train, - const NDArray &data, const NDArray &weights, + const NDArray &data, const NDArray &weight, const NDArray *bias, const NDArray &output) - : fwd_pd(GetConvFwdImpl(param, is_train, data, weights, bias, output)) { - data_ = std::make_shared(fwd_pd.src_primitive_desc(), nullptr); - weight_ = std::make_shared(fwd_pd.weights_primitive_desc(), nullptr); - out_ = std::make_shared(fwd_pd.dst_primitive_desc(), nullptr); - if (bias) { - bias_ = std::make_shared(fwd_pd.bias_primitive_desc(), nullptr); - fwd_ = std::make_shared(fwd_pd, *this->data_, *this->weight_, - *this->bias_, *this->out_); - } else { - fwd_ = std::make_shared(fwd_pd, *this->data_, *this->weight_, - *this->out_); - } + : pd_(GetConvFwdImpl(param, is_train, data, weight, bias, output)) { + fwd_ = std::make_shared(GetPd()); } -void MKLDNNConvForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, - const mkldnn::memory *bias, const mkldnn::memory &output) { - data_->set_data_handle(data.get_data_handle()); - weight_->set_data_handle(weight.get_data_handle()); - out_->set_data_handle(output.get_data_handle()); - if (bias != nullptr) bias_->set_data_handle(bias->get_data_handle()); -} - -MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param, - const bool is_train, const NDArray &data, - const NDArray &weights, const NDArray *bias, +MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam &param, const bool is_train, + const NDArray &data, const NDArray &weight, const NDArray *bias, const NDArray &output) { + using conv_fwd_map = std::unordered_map; #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map fwds; + static thread_local conv_fwd_map fwds; #else - static MX_THREAD_LOCAL std::unordered_map fwds; + static MX_THREAD_LOCAL conv_fwd_map fwds; #endif - MKLDNNConvSignature key(param); + // TODO(zhennan): Hash conv_param for now, need to hash full param if we want to enable cache for + // fused conv + MKLDNNConvSignature key(param.conv_param); key.AddSign(is_train); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. + // Here we can sign the conv op with NDArray because conv primitive will decide the right layout + // for them, so we only need to get the shape and the data type of the arrays. 
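+ // (The key therefore ignores each array's current MKL-DNN layout: one cached
+ // primitive serves all layouts of a given shape and dtype, and inputs are
+ // simply reordered to whatever layout the cached primitive expects.)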
key.AddSign(data); - key.AddSign(weights); + key.AddSign(weight); key.AddSign(output); - if (bias) - key.AddSign(*bias); + if (bias) key.AddSign(*bias); auto it = fwds.find(key); if (it == fwds.end()) { - MKLDNNConvFullParam full_param; - full_param.conv_param = param; - full_param.mkldnn_param.Init(std::unordered_map()); - MKLDNNConvForward fwd(full_param, is_train, data, weights, bias, output); + auto fwd = MKLDNNConvForward(param, is_train, data, weight, bias, output); it = AddToCache(&fwds, key, fwd); } return it->second; } -void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, - const OpContext &ctx, +void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam ¶m, const OpContext &ctx, MKLDNNConvForward *fwd, const std::vector &in_data, const std::vector &req, const std::vector &out_data) { TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]); - auto data = in_data[conv::kData]; - if (data.IsView() && data.IsMKLDNNData()) - data = data.Reorder2Default(); - - auto weight = in_data[conv::kWeight]; - if (weight.IsView() && weight.IsMKLDNNData()) - weight = weight.Reorder2Default(); - + auto &data = in_data[conv::kData]; + auto &weight = in_data[conv::kWeight]; bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn; - auto data_mem = data.GetMKLDNNDataReorder( - fwd->fwd_pd.src_primitive_desc()); + auto data_mem = data.GetMKLDNNDataReorder(fwd->GetPd().src_desc()); const mkldnn::memory *weight_mem; if (ctx.is_train) { - // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it - // to the default format for now. + // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it to the default format + // for now. if (weight.IsMKLDNNData()) - // This asks the engine to change the layout of the weight array after - // it's used. + // This asks the engine to change the layout of the weight array after it's used. weight.Reorder2DefaultAsync(); - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), - param.conv_param.num_group); + weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group); } else { - // For inference, we want to reorder the weight array so we don't need to - // reorder data every time. + // For inference, we want to reorder the weight array so we don't need to reorder data every + // time. if (weight.IsDefaultData()) { - // We also need to modify the layout on the original weight array. The - // data conversion happens after the weight array is used. - weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), - param.conv_param.num_group); - + // We also need to modify the layout on the original weight array. The data conversion happens + // after the weight array is used. 
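+ // (MKLDNNDataReorderAsync captures the array's version; if memory planning
+ // recycles the NDArray before the engine runs the reorder, the stale request
+ // is skipped -- see the version check added in ndarray.cc above.)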
+ weight.MKLDNNDataReorderAsync(fwd->GetPd().weights_desc()); + weight_mem = GetWeights(weight, fwd->GetPd().weights_desc(), param.conv_param.num_group); } else { weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); + CHECK(weight_mem->get_desc() == fwd->GetPd().weights_desc()); } } mkldnn_output_t out_mem; if (param.mkldnn_param.with_sum) { - out_mem = mkldnn_output_t( - OutDataOp::Noop, - const_cast(out_data[conv::kOut].GetMKLDNNData())); + out_mem = mkldnn_output_t(OutDataOp::Noop, + const_cast(out_data[conv::kOut].GetMKLDNNData())); } else { - out_mem = CreateMKLDNNMem(out_data[conv::kOut], - fwd->fwd_pd.dst_primitive_desc(), req[conv::kOut]); + out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd->GetPd().dst_desc(), req[conv::kOut]); } - const mkldnn::memory *bias_mem = nullptr; + mkldnn_args_map_t net_args; if (!no_bias) { - bias_mem = in_data[conv::kBias].GetMKLDNNData(); + const mkldnn::memory *bias_mem = in_data[conv::kBias].GetMKLDNNData(); + net_args.insert({MKLDNN_ARG_BIAS, *bias_mem}); } - fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd()); + net_args.insert({MKLDNN_ARG_SRC, *data_mem}); + net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem}); + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), net_args); CommitOutput(out_data[conv::kOut], out_mem); MKLDNNStream::Get()->Submit(); } -void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, +void MKLDNNConvolutionForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &in_data, const std::vector &req, const std::vector &out_data) { MKLDNNConvFullParam param; param.conv_param = nnvm::get(attrs.parsed); param.mkldnn_param.Init(std::unordered_map()); - auto &fwd = GetConvFwd( - param.conv_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], - param.conv_param.no_bias ? nullptr : &in_data[conv::kBias], - out_data[conv::kOut]); + auto &fwd = + GetConvFwd(param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight], + param.conv_param.no_bias ? 
nullptr : &in_data[conv::kBias], out_data[conv::kOut]); MKLDNNConvolutionForwardFullFeature(param, ctx, &fwd, in_data, req, out_data); } -class MKLDNNConvBackward { - std::shared_ptr bwd_data; - std::shared_ptr bwd_weight; - // conv::kData - std::shared_ptr out_grad; - std::shared_ptr in_grad; - std::shared_ptr weight; - // conv::kWeight - std::shared_ptr data; - std::shared_ptr output; - std::shared_ptr in_grad_weight; - std::shared_ptr in_grad_bias; - - public: - mkldnn::convolution_backward_data::primitive_desc bwdData_pd; - mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd; - - MKLDNNConvBackward( - const ConvolutionParam ¶m, const NDArray &data, - const NDArray &weights, const NDArray *bias, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &fwd_pd): - bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)), - bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) { - } - - void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight, - const mkldnn::memory &in_grad) { - if (this->out_grad == nullptr) - this->out_grad = std::shared_ptr(new mkldnn::memory( - bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); - else - this->out_grad->set_data_handle(out_grad.get_data_handle()); - if (this->in_grad == nullptr) - this->in_grad = std::shared_ptr(new mkldnn::memory( - bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle())); - else - this->in_grad->set_data_handle(in_grad.get_data_handle()); - if (this->weight == nullptr) - this->weight = std::shared_ptr(new mkldnn::memory( - bwdData_pd.weights_primitive_desc(), weight.get_data_handle())); - else - this->weight->set_data_handle(weight.get_data_handle()); - if (this->bwd_data == nullptr) - this->bwd_data = std::shared_ptr( - new mkldnn::convolution_backward_data( - this->bwdData_pd, mkldnn::primitive::at(*this->out_grad), - mkldnn::primitive::at(*this->weight), *this->in_grad)); - } - - void SetWeightNewMem(const mkldnn::memory &data, - const mkldnn::memory &out_grad, - const mkldnn::memory &in_grad_weight) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( - bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); - if (this->output == nullptr) - this->output = std::shared_ptr(new mkldnn::memory( - bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle())); - else - this->output->set_data_handle(out_grad.get_data_handle()); - if (this->in_grad_weight == nullptr) - this->in_grad_weight = std::shared_ptr( - new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(), - in_grad_weight.get_data_handle())); - else - this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); - - if (this->bwd_weight == nullptr) - this->bwd_weight = std::shared_ptr( - new mkldnn::convolution_backward_weights( - this->bwdWeights_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->output), *this->in_grad_weight)); - } - - void SetWeightNewMem(const mkldnn::memory &data, - const mkldnn::memory &out_grad, - const mkldnn::memory &in_grad_weight, - const mkldnn::memory &in_grad_bias) { - if (this->data == nullptr) - this->data = std::shared_ptr(new mkldnn::memory( - bwdWeights_pd.src_primitive_desc(), data.get_data_handle())); - else - this->data->set_data_handle(data.get_data_handle()); - if (this->output == nullptr) - this->output = std::shared_ptr(new mkldnn::memory( - bwdWeights_pd.diff_dst_primitive_desc(), 
out_grad.get_data_handle())); - else - this->output->set_data_handle(out_grad.get_data_handle()); - if (this->in_grad_weight == nullptr) - this->in_grad_weight = std::shared_ptr( - new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(), - in_grad_weight.get_data_handle())); - else - this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle()); - - if (this->in_grad_bias == nullptr) - this->in_grad_bias = std::shared_ptr( - new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(), - in_grad_bias.get_data_handle())); - else - this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle()); - if (this->bwd_weight == nullptr) - this->bwd_weight = std::shared_ptr( - new mkldnn::convolution_backward_weights( - this->bwdWeights_pd, mkldnn::primitive::at(*this->data), - mkldnn::primitive::at(*this->output), *this->in_grad_weight, - *this->in_grad_bias)); - } - - const mkldnn::convolution_backward_data &GetBwdData() const { - return *bwd_data; - } - - const mkldnn::convolution_backward_weights &GetBwdWeights() const { - return *bwd_weight; - } -}; +MKLDNNConvBackward::MKLDNNConvBackward(const MKLDNNConvFullParam &param, const NDArray &data, + const NDArray &weight, const NDArray *bias, + const NDArray &output) { + const auto fwd_pd = GetConvFwdImpl(param, true, data, weight, bias, output); + bwd_data_pd_ = GetConvBwdData(param.conv_param, data, weight, output, *fwd_pd); + bwd_weight_pd_ = GetConvBwdWeights(param.conv_param, data, weight, bias, output, *fwd_pd); + bwd_data_ = std::make_shared(GetDataPd()); + bwd_weight_ = std::make_shared(GetWeightsPd()); +} -static inline MKLDNNConvBackward &GetConvBwd( - const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights, - const NDArray *bias, const NDArray &output, - const mkldnn::convolution_forward::primitive_desc &fwd_pd) { +static inline MKLDNNConvBackward &GetConvBwd(const MKLDNNConvFullParam &param, const NDArray &data, + const NDArray &weight, const NDArray *bias, + const NDArray &output) { + using mkldnn_conv_bwd_map = std::unordered_map; #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map bwds; + static thread_local mkldnn_conv_bwd_map bwds; #else - static MX_THREAD_LOCAL std::unordered_map bwds; + static MX_THREAD_LOCAL mkldnn_conv_bwd_map bwds; #endif - const ConvolutionParam& param = nnvm::get(attrs.parsed); - MKLDNNConvSignature key(param); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. + // TODO(zhennan): Hash conv_param for now, need to hash full param if we want to enable cache for + // fused conv + MKLDNNConvSignature key(param.conv_param); + // Here we can sign the conv op with NDArray because conv primitive will decide the right layout + // for them, so we only need to get the shape and the data type of the arrays. 
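
Stepping back from the diff for a moment: the RegisterPrimArgs calls introduced by this patch feed MKL-DNN v1.0's execution model, in which a primitive no longer owns its memories and is instead executed against an argument map. A self-contained sketch of that model, assuming the MKL-DNN v1.0 headers (problem sizes are arbitrary example values, not taken from the patch):

    #include <mkldnn.hpp>

    int main() {
      using namespace mkldnn;
      engine eng(engine::kind::cpu, 0);
      stream strm(eng);

      // Arbitrary example problem: 1x3x32x32 input, 8 3x3 filters, no padding.
      memory::desc src_md({1, 3, 32, 32}, memory::data_type::f32, memory::format_tag::nchw);
      memory::desc wei_md({8, 3, 3, 3}, memory::data_type::f32, memory::format_tag::oihw);
      memory::desc dst_md({1, 8, 30, 30}, memory::data_type::f32, memory::format_tag::nchw);

      convolution_forward::desc d(prop_kind::forward_inference,
                                  algorithm::convolution_direct,
                                  src_md, wei_md, dst_md,
                                  {1, 1} /*strides*/, {0, 0} /*pad_l*/, {0, 0} /*pad_r*/);
      convolution_forward::primitive_desc pd(d, eng);

      memory src(pd.src_desc(), eng), wei(pd.weights_desc(), eng), dst(pd.dst_desc(), eng);

      // v1.0 execution model: one primitive, one stream, one map of arguments.
      convolution_forward(pd).execute(
          strm, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_WEIGHTS, wei}, {MKLDNN_ARG_DST, dst}});
      strm.wait();
      return 0;
    }

MXNet's MKLDNNStream batches such (primitive, args) pairs via RegisterPrimArgs and replays them on one stream at Submit().
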
key.AddSign(data); - key.AddSign(weights); + key.AddSign(weight); key.AddSign(output); - if (bias) - key.AddSign(*bias); - + if (bias) key.AddSign(*bias); auto it = bwds.find(key); if (it == bwds.end()) { - MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd); + auto bwd = MKLDNNConvBackward(param, data, weight, bias, output); it = AddToCache(&bwds, key, bwd); } return it->second; @@ -617,69 +460,53 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct full_param.conv_param = nnvm::get(attrs.parsed); full_param.mkldnn_param.Init(std::unordered_map()); - auto data = inputs[conv::kData + 1]; - if (data.IsView() && data.IsMKLDNNData()) - data = data.Reorder2Default(); + auto &data = inputs[conv::kData + 1]; + auto &weight = inputs[conv::kWeight + 1]; + const auto *bias = full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1]; + auto &out_grad = inputs[conv::kOut]; - auto weight = inputs[conv::kWeight + 1]; - if (weight.IsView() && weight.IsMKLDNNData()) - weight = weight.Reorder2Default(); - - const NDArray* bias = full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1]; - - auto out_grad = inputs[conv::kOut]; - if (out_grad.IsView() && out_grad.IsMKLDNNData()) - out_grad = out_grad.Reorder2Default(); - - mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl( - full_param, ctx.is_train, data, weight, bias, out_grad); const ConvolutionParam ¶m = full_param.conv_param; CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace"; - MKLDNNConvBackward &convBwd = GetConvBwd(attrs, data, - weight, bias, out_grad, fwd_pd); - auto out_grad_mem = out_grad.GetMKLDNNDataReorder( - convBwd.bwdData_pd.diff_dst_primitive_desc()); + MKLDNNConvBackward &convBwd = GetConvBwd(full_param, data, weight, bias, out_grad); + auto out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetDataPd().diff_dst_desc()); if (req[conv::kData]) { - auto weight_mem = GetWeights(weight, - convBwd.bwdData_pd.weights_primitive_desc(), param.num_group); - auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], - convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]); - convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second); - MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData()); + auto weight_mem = GetWeights(weight, convBwd.GetDataPd().weights_desc(), param.num_group); + auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], convBwd.GetDataPd().diff_src_desc(), + req[conv::kData]); + MKLDNNStream::Get()->RegisterPrimArgs(convBwd.GetBwdData(), + {{MKLDNN_ARG_DIFF_DST, *out_grad_mem}, + {MKLDNN_ARG_WEIGHTS, *weight_mem}, + {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}}); CommitOutput(in_grad[conv::kData], in_grad_mem); } if (req[conv::kWeight] || req[conv::kBias]) { - MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, data, - weight, bias, out_grad, fwd_pd); - if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() != - convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc()) - out_grad_mem = out_grad.GetMKLDNNDataReorder( - convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc()); - auto data_mem = data.GetMKLDNNDataReorder( - convBwdWeight.bwdWeights_pd.src_primitive_desc()); + if (convBwd.GetDataPd().diff_dst_desc() != convBwd.GetWeightsPd().diff_dst_desc()) + out_grad_mem = out_grad.GetMKLDNNDataReorder(convBwd.GetWeightsPd().diff_dst_desc()); + auto data_mem = data.GetMKLDNNDataReorder(convBwd.GetWeightsPd().src_desc()); auto in_grad_weight = CreateMKLDNNWeightGrad( - in_grad[conv::kWeight], - 
convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(), - req[conv::kWeight]); - if (param.no_bias) { - convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem, - *in_grad_weight.second); - MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights()); - } else { - auto in_grad_bias = CreateMKLDNNMem( - in_grad[conv::kBias], - convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]); - convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem, - *in_grad_weight.second, *in_grad_bias.second); - MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights()); - CommitOutput(in_grad[conv::kBias], in_grad_bias); + in_grad[conv::kWeight], convBwd.GetWeightsPd().diff_weights_desc(), req[conv::kWeight]); + + mkldnn_args_map_t net_args = {{MKLDNN_ARG_DIFF_DST, *out_grad_mem}, + {MKLDNN_ARG_SRC, *data_mem}, + {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}}; + mkldnn_output_t in_grad_bias; + if (!param.no_bias) { + in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias], + convBwd.GetWeightsPd().diff_bias_desc(), + req[conv::kBias]); + net_args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second}); } + MKLDNNStream::Get()->RegisterPrimArgs(convBwd.GetBwdWeights(), net_args); CommitOutput(in_grad[conv::kWeight], in_grad_weight); + // CommitOutput should run after RegisterPrimArgs because of the memory dependency + if (!param.no_bias) { + CommitOutput(in_grad[conv::kBias], in_grad_bias); + } } MKLDNNStream::Get()->Submit(); } } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100 diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 122ad9fd0686..ddfcecce2bce 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -54,16 +54,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &req, const std::vector &outputs); -/* For convolution. */ -void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data); -void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs); - /* For deconvolution */ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, @@ -133,6 +123,17 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, #endif #if MXNET_USE_MKLDNN == 100 +/* For convolution. */ +void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); +void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, const mkldnn::memory &out); #endif diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 5290c09ec00d..753be48c11ab 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -526,17 +526,46 @@ class OpSignature { * and the layout to sign the op. 
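 * With MKL-DNN v1.0 a layout is no longer a single format enum, so the
 * signature below hashes format_kind together with the blocking, Winograd,
 * or RNN-packed details carried in the descriptor.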
*/ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 void AddSign(const mkldnn::memory &mem) { - auto desc = mem.get_primitive_desc().desc(); - hash = hash * 2 + desc.data.format; - eles.push_back(desc.data.format); + auto desc = mem.get_desc(); + hash = hash * 2 + desc.data.format_kind; + eles.push_back(desc.data.format_kind); hash = hash * 2 + desc.data.data_type; eles.push_back(desc.data.data_type); for (int i = 0; i < desc.data.ndims; i++) { hash = hash * 2 + desc.data.dims[i]; eles.push_back(desc.data.dims[i]); } + switch (desc.data.format_kind) { + case mkldnn_blocked: + hash = hash * 2 + desc.data.ndims; + eles.push_back(desc.data.ndims); + for (int i = 0; i < desc.data.ndims; i++) { + hash = hash * 2 + desc.data.format_desc.blocking.strides[i]; + eles.push_back(desc.data.format_desc.blocking.strides[i]); + } + hash = hash * 2 + desc.data.format_desc.blocking.inner_nblks; + eles.push_back(desc.data.format_desc.blocking.inner_nblks); + for (int i = 0; i < desc.data.format_desc.blocking.inner_nblks; i++) { + hash = hash * 2 + desc.data.format_desc.blocking.inner_blks[i]; + hash = hash * 2 + desc.data.format_desc.blocking.inner_idxs[i]; + eles.push_back(desc.data.format_desc.blocking.inner_blks[i]); + eles.push_back(desc.data.format_desc.blocking.inner_idxs[i]); + } + break; + case mkldnn_format_kind_wino: + hash = hash * 2 + desc.data.format_desc.wino_desc.wino_format; + eles.push_back(desc.data.format_desc.wino_desc.wino_format); + break; + case mkldnn_format_kind_rnn_packed: + hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.format; + eles.push_back(desc.data.format_desc.rnn_packed_desc.format); + break; + default: + // nothing more needs to be added + break; + } } #endif @@ -547,7 +576,7 @@ class OpSignature { } void AddSign(const NDArray &arr) { -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (arr.IsMKLDNNData()) { AddSign(*(arr.GetMKLDNNData())); } else { @@ -555,7 +584,7 @@ class OpSignature { hash = hash * 2 + arr.dtype(); eles.push_back(arr.dtype()); AddSign(arr.shape()); -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 } #endif } diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h index 93606fcde86f..4a8a273334b6 100644 --- a/src/operator/tensor/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -34,7 +34,7 @@ #ifdef __CUDACC__ #include "./cast_storage-inl.cuh" #endif // __CUDACC__ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "../nn/mkldnn/mkldnn_base-inl.h" #endif @@ -397,7 +397,7 @@ void CastStorageComputeImpl(const OpContext& ctx, } else if (src_stype == kRowSparseStorage && dst_stype == kRowSparseStorage) { NDArray ret = output; CastStorageRspRspImpl(ctx, input, &ret); -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 } else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) { CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type); // If one of them uses the MKLDNN layout. @@ -449,7 +449,7 @@ inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, if (!dispatched && in_stype == kDefaultStorage && param_stype == kDefaultStorage) { // dns -> dns DispatchMode mode = DispatchMode::kFCompute; -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 // If we use MKLDNN and the arrays are in CPU memory, the array may store // MKLDNN layout, so we should convert its layout explicitly. if (dev_mask == kCPU)