diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index 5abb6670c9b0..f238e8f72542 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -27,10 +27,10 @@
 #include "./activation-inl.h"
 #include "../mshadow_op.h"
 #include "../tensor/elemwise_unary_op.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_base-inl.h"
 #include "./mkldnn/mkldnn_ops-inl.h"
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #include "../operator_common.h"
 #include "../../common/utils.h"
 
@@ -91,7 +91,7 @@ struct ActivationGrad {
   }
 };
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<NDArray>& inputs,
@@ -150,7 +150,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
   return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
                            dispatch_mode, in_attrs, out_attrs);
 }
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 
 
 MXNET_OPERATOR_REGISTER_UNARY(Activation)
@@ -167,7 +167,7 @@ The following activation functions are supported:
 
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<ActivationParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", ActivationStorageType)
 #endif
 .set_attr<nnvm::FListOutputNames>("FListOutputNames",
@@ -175,7 +175,7 @@ The following activation functions are supported:
     return std::vector<std::string>{"output"};
 })
 .set_attr<FCompute>("FCompute<cpu>", ActivationCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ActivationComputeExCPU)
 #endif
@@ -189,7 +189,7 @@ NNVM_REGISTER_OP(_backward_Activation)
 })
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
 #endif
 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
@@ -197,13 +197,13 @@ NNVM_REGISTER_OP(_backward_Activation)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}};
 })
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 #endif
 .set_attr_parser(ParamParser<ActivationParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ActivationGradComputeExCPU)
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_act-inl.h b/src/operator/nn/mkldnn/mkldnn_act-inl.h
index 6bf30e3f3bbe..57507a5817bb 100644
--- a/src/operator/nn/mkldnn/mkldnn_act-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_act-inl.h
@@ -20,7 +20,7 @@
 /*!
  * Copyright (c) 2019 by Contributors
  * \file mkldnn_act-inl.h
- * \brief MKLDNN(Quantized) Activation operator based on subgraph
+ * \brief MKLDNN Activation operator
  * \author Zhiyuan Huang
  */
 
@@ -28,20 +28,17 @@
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <vector>
 #include <utility>
 #include "../activation-inl.h"
-#include "./mkldnn_ops-inl.h"
-#include "./mkldnn_base-inl.h"
 
 namespace mxnet {
 namespace op {
 
 mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param);
 mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
-    const ActivationParam& param, bool is_train,
-    const mkldnn::memory &input_mem, int dtype);
+    const ActivationParam& param, bool is_train, const mkldnn::memory &input_mem);
 
 class MKLDNNActForward {
  public:
@@ -49,14 +46,13 @@ class MKLDNNActForward {
   MKLDNNActForward(const ActivationParam& param, bool is_train,
                    const NDArray &data, const mkldnn::memory &mem): fwd_pd(
-                       GetActFwdDescImpl(param, is_train, mem, data.dtype())) {}
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output);
-  const mkldnn::eltwise_forward &GetFwd() const;
+                       GetActFwdDescImpl(param, is_train, mem)) {
+    fwd_ = std::make_shared<mkldnn::eltwise_forward>(fwd_pd);
+  }
+  const inline mkldnn::eltwise_forward &GetFwd() const;
 
  private:
   std::shared_ptr<mkldnn::eltwise_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> out_;
 };
 
 typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
@@ -67,8 +63,28 @@ MKLDNNActForward &GetActForward(const ActivationParam& param,
 void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                              const NDArray &in_data, const OpReqType &req,
                              const NDArray &out_data);
+
+mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+    const ActivationParam &param, const mkldnn::memory &input_mem,
+    const mkldnn::memory &diff_dst_memory);
+
+class MKLDNNActBackward {
+ public:
+  const mkldnn::eltwise_backward::primitive_desc pd;
+
+  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
+                             const mkldnn::memory &mem,
+                             const mkldnn::memory &diff_dst_memory): pd(
+                                 GetActBwdDescImpl(param, mem, diff_dst_memory)) {
+    bwd = std::make_shared<mkldnn::eltwise_backward>(pd);
+  }
+  const inline mkldnn::eltwise_backward &GetBwd() const;
+
+ private:
+  std::shared_ptr<mkldnn::eltwise_backward> bwd;
+};
 }  // namespace op
 }  // namespace mxnet
 
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index e4c829645e13..e2ffd0b40da4 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -23,6 +23,8 @@
  * \author Da Zheng
  */
 
+#if MXNET_USE_MKLDNN == 100
+
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
 #include <mxnet/operator.h>
@@ -33,10 +35,7 @@
 #include <utility>
 #include "../../operator_common.h"
 #include "mkldnn_act-inl.h"
-
-#if MXNET_USE_MKLDNN == 1
-
-#include <mkldnn.hpp>
+#include "./mkldnn_base-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -81,41 +80,19 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
 
 mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
     const ActivationParam& param, bool is_train,
-    const mkldnn::memory &input_mem, int dtype) {
-  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  auto cpu_engine = data_mpd.get_engine();
-
+    const mkldnn::memory &input_mem) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
   auto alg = GetMKLDNNActAlgo(param);
   auto prop = is_train ?
       mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
   auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f);
-  return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
-}
-
-void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
-  if (this->data_ == nullptr)
-    this->data_ = std::make_shared<mkldnn::memory>(data.get_primitive_desc(),
-                                                   data.get_data_handle());
-  else
-    this->data_->set_data_handle(data.get_data_handle());
-  CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
-  if (this->out_ == nullptr)
-    this->out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(),
-                                                  output.get_data_handle());
-  else
-    this->out_->set_data_handle(output.get_data_handle());
-
-  if (this->fwd_ == nullptr) {
-    this->fwd_ = std::shared_ptr<mkldnn::eltwise_forward>(
-        new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
-                                    *this->out_));
-  }
+  return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
 }
 
-const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
+const inline mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
   return *fwd_;
 }
 
@@ -131,7 +108,6 @@ MKLDNNActForward &GetActForward(const ActivationParam& param,
   key.AddSign(ctx.is_train);
   key.AddSign(param.act_type);
   key.AddSign(in_data);
-
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem);
@@ -153,81 +129,34 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
 
   auto input_mem = in_buffer.GetMKLDNNData();
   MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
-  auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
-  fwd.SetNewMem(*input_mem, *out_mem_t.second);
-  stream->RegisterPrim(fwd.GetFwd());
+  auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
+  stream->RegisterPrimArgs(fwd.GetFwd(),
+                           {{ MKLDNN_ARG_SRC, *input_mem}, { MKLDNN_ARG_DST, *out_mem_t.second}});
   CommitOutput(out_data, out_mem_t);
   stream->Submit();
 }
 
-static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
     const ActivationParam &param, const mkldnn::memory &input_mem,
-    const mkldnn::memory &diff_dst_memory, int dtype) {
-  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
-  auto cpu_engine = data_mpd.get_engine();
+    const mkldnn::memory &diff_dst_memory) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  mkldnn::memory::desc diff_md = diff_dst_memory.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
   auto alg = GetMKLDNNActAlgo(param);
 
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    DType alpha = 0;
-    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                          alg, data_md, alpha);
-    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
-    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
-                                                      fw_pdesc);
-    return bw_pdesc;
-  });
-  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
+  float alpha = 0;
   mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                        alg, data_md, 0.0);
+                                        alg, data_md, alpha);
   mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
+  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
   mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
                                                     fw_pdesc);
   return bw_pdesc;
 }
 
-class MKLDNNActBackward {
-  std::shared_ptr<mkldnn::eltwise_backward> bwd;
-  std::shared_ptr<mkldnn::memory> data;
-  std::shared_ptr<mkldnn::memory> diff_dst_memory;
-  std::shared_ptr<mkldnn::memory> diff_src_memory;
-
- public:
-  const mkldnn::eltwise_backward::primitive_desc pd;
-
-  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
-                             const mkldnn::memory &mem,
-                             const mkldnn::memory &diff_dst_memory)
-      : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}
-
-  void SetNewMem(const mkldnn::memory &data,
-                 const mkldnn::memory &diff_dst_memory,
-                 const mkldnn::memory &diff_src_memory) {
-    if (this->bwd != nullptr) {
-      this->data->set_data_handle(data.get_data_handle());
-      this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
-      this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
-    } else {
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          data.get_primitive_desc(), data.get_data_handle()));
-      this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
-                             diff_dst_memory.get_data_handle()));
-      this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(diff_src_memory.get_primitive_desc(),
-                             diff_src_memory.get_data_handle()));
-      this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
-          new mkldnn::eltwise_backward(
-              this->pd, mkldnn::primitive::at(*this->data),
-              *this->diff_dst_memory, *this->diff_src_memory));
-    }
-  }
-
-  const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
-};
+const inline mkldnn::eltwise_backward &MKLDNNActBackward::GetBwd() const {
+  return *bwd;
+}
 
 static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
                                                 const OpContext &ctx,
@@ -274,20 +203,23 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   auto input_mem = in_buffer.GetMKLDNNData();
   // We need to make sure the two inputs to eltwise_backward have the same memory
   // descriptor. Otherwise, performance will suffer.
-  if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
-    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
+  if (input_mem->get_desc() != diff_dst_memory->get_desc())
+    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
   MKLDNNActBackward &bwd =
       GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream *stream = MKLDNNStream::Get();
   mkldnn_output_t diff_src_memory =
-      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
-  bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
-  stream->RegisterPrim(bwd.GetBwd());
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
+  mkldnn_args_map_t args = {
+    { MKLDNN_ARG_SRC, *input_mem },
+    { MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
+    { MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second },
+  };
+  stream->RegisterPrimArgs(bwd.GetBwd(), args);
   CommitOutput(in_grad, diff_src_memory);
   stream->Submit();
 }
 
 }  // namespace op
 }  // namespace mxnet
-
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index ddfcecce2bce..3c83f6b6bc56 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -95,14 +95,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs);
 
-/* For activation */
-void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                             const NDArray &in_data, const OpReqType &req,
-                             const NDArray &out_data);
-void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                              const NDArray &out_grad, const NDArray &in_data,
-                              const OpReqType &req, const NDArray &in_grad);
-
 void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const OpContext &ctx,
                             const NDArray &data,
@@ -133,6 +125,13 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs);
 
+/* For activation */
+void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                             const NDArray &in_data, const OpReqType &req,
+                             const NDArray &out_data);
+void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                              const NDArray &out_grad, const NDArray &in_data,
+                              const OpReqType &req, const NDArray &in_grad);
 void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
                const mkldnn::memory &out);
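
The change is the same in every file touched: the MKLDNN 0.x idiom, where a cached primitive owned mkldnn::memory objects and had to be rebound through SetNewMem() before stream->RegisterPrim(), becomes the MKLDNN 1.0 idiom, where a primitive is built once from its primitive_desc and memory is bound per execution through an argument map passed to stream->RegisterPrimArgs(). The following is a minimal standalone sketch of that 1.0 execution model against plain mkldnn.hpp, not code from this patch; names such as eng, strm, src, dst, and relu are illustrative only.

#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);  // the role CpuEngine::Get()->get_engine() plays above
  stream strm(eng);

  // memory::desc replaces 0.x's memory::primitive_desc (cf. get_desc() in the patch).
  memory::desc md({1, 32}, memory::data_type::f32, memory::format_tag::nc);
  memory src(md, eng);
  memory dst(md, eng);

  // Build the eltwise (ReLU) primitive once from its primitive_desc; it holds
  // no data pointers, so it can be cached and reused, as MKLDNNActForward now does.
  eltwise_forward::desc d(prop_kind::forward_training,
                          algorithm::eltwise_relu, md, 0.0f);
  eltwise_forward::primitive_desc pd(d, eng);
  eltwise_forward relu(pd);

  // Memory is supplied at execution time via MKLDNN_ARG_* keys, the same
  // argument-map shape that RegisterPrimArgs() records for later execution.
  relu.execute(strm, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});
  strm.wait();
  return 0;
}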