diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index 9b9f0193979b..827fb10155cf 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -24,7 +24,7 @@ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include #include @@ -43,33 +43,26 @@ class MKLDNNPoolingFwd { const int padding_t, const int padding_b, const int padding_l, const int padding_r, const mkldnn::algorithm alg_kind, - const bool with_workspace, const bool is_train) : - is_train_(is_train), + const bool with_workspace, const bool is_train): with_workspace_(with_workspace), - alg_kind_(alg_kind), - fwd_(nullptr), data_(nullptr), out_(nullptr), workspace_(nullptr) { + fwd_(nullptr) { Init(input, output, kernel_h, kernel_w, stride_h, stride_w, - padding_t, padding_b, padding_l, padding_r); + padding_t, padding_b, padding_l, padding_r, + is_train, alg_kind); } ~MKLDNNPoolingFwd() {} - void SetNewMem(const NDArray& in_data, - const NDArray& out_data, - const OpReqType& req, - const mxnet::NDArray *workspace = nullptr); - void Execute(const NDArray& out_data); + void Execute(const NDArray &in_data, + const OpReqType req, + const NDArray& out_data, + const NDArray *workspace); private: - bool is_train_; bool with_workspace_; - mkldnn::algorithm alg_kind_; + std::shared_ptr fwd_pd_; std::shared_ptr fwd_; - std::shared_ptr data_; - std::shared_ptr out_; - std::shared_ptr workspace_; - mkldnn_output_t output_mem_t_; private: void Init(const mxnet::NDArray &input, @@ -77,26 +70,21 @@ class MKLDNNPoolingFwd { const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int padding_t, const int padding_b, - const int padding_l, const int padding_r); + const int padding_l, const int padding_r, + const bool is_train, const mkldnn::algorithm alg_kind); }; class MKLDNNPoolingBwd { std::shared_ptr bwd; - std::shared_ptr diff_dst; - std::shared_ptr diff_src; - std::shared_ptr ws; bool with_workspace; public: const mkldnn::pooling_backward::primitive_desc pd; - MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc, + MKLDNNPoolingBwd(const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws); ~MKLDNNPoolingBwd() {} - void SetNewMem(const mxnet::NDArray *workspace, - const mxnet::NDArray &out_grad, - const mkldnn::memory *diff_src_mem); const mkldnn::pooling_backward &GetBwd(); const mkldnn::pooling_backward::primitive_desc &GetPd(); }; @@ -141,5 +129,5 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, const NDArray &output); } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100 #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index f4d681ded78d..1b2449b137a6 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -23,7 +23,7 @@ * \author Tao Lv */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "./mkldnn_pooling-inl.h" @@ -34,18 +34,17 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int padding_t, const int padding_b, - const int padding_l, const int padding_r) { - // mkldnn::memory::desc - auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc(); + const int padding_l, const int padding_r, + const bool is_train, const mkldnn::algorithm alg_kind) { + auto src_md = input.GetMKLDNNData()->get_desc(); mkldnn::memory::dims dims = {src_md.data.dims[0], src_md.data.dims[1], static_cast(output.shape()[2]), static_cast(output.shape()[3])}; auto dst_md = mkldnn::memory::desc({dims}, static_cast(src_md.data.data_type), - static_cast(src_md.data.format)); + mkldnn::memory::format_tag::any); const mkldnn::engine engine = CpuEngine::Get()->get_engine(); - const mkldnn::algorithm alg_kind = this->alg_kind_; if (alg_kind != mkldnn::algorithm::pooling_max && alg_kind != mkldnn::algorithm::pooling_avg && alg_kind != mkldnn::algorithm::pooling_avg_include_padding && @@ -54,10 +53,10 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o } mkldnn::prop_kind prop = mkldnn::prop_kind::forward_scoring; - if (this->is_train_ && alg_kind != mkldnn::algorithm::pooling_avg) { + if (is_train && alg_kind != mkldnn::algorithm::pooling_avg) { prop = mkldnn::prop_kind::forward_training; } - if (this->is_train_ && prop == mkldnn::prop_kind::forward_scoring) { + if (is_train && prop == mkldnn::prop_kind::forward_scoring) { LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring"; } @@ -67,49 +66,38 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o const mkldnn::memory::dims kernel = {kernel_h, kernel_w }; // mkldnn::pooling_forward::desc const auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md, - strides, kernel, pad_l, pad_r, - mkldnn::padding_kind::zero); + strides, kernel, pad_l, pad_r); this->fwd_pd_.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine)); - this->data_.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc())); - this->out_.reset(new mkldnn::memory(this->fwd_pd_->dst_primitive_desc())); - if (this->with_workspace_) { - this->workspace_.reset(new mkldnn::memory(this->fwd_pd_->workspace_primitive_desc())); - this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_), - mkldnn::primitive::at(*(this->data_)), - *(this->out_), - *(this->workspace_))); - } else { - this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_), - mkldnn::primitive::at(*(this->data_)), - *(this->out_))); - } + this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_))); + return; } -void MKLDNNPoolingFwd::SetNewMem(const NDArray& in_data, - const NDArray& out_data, - const OpReqType& req, - const mxnet::NDArray *workspace) { - auto input_mem = in_data.GetMKLDNNData(); - output_mem_t_ = CreateMKLDNNMem(out_data, fwd_pd_->dst_primitive_desc(), req); - // mkldnn::memory - this->data_->set_data_handle(input_mem->get_data_handle()); - this->out_->set_data_handle(output_mem_t_.second->get_data_handle()); - if (this->with_workspace_ && workspace == nullptr) { - LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; - } +void MKLDNNPoolingFwd::Execute(const NDArray &in_data, + const OpReqType req, + const NDArray& out_data, + const NDArray *workspace) { + NDArray in_buffer = in_data; + if (in_data.IsView() && in_data.IsMKLDNNData()) + in_buffer = in_data.Reorder2Default(); + + auto input_mem = in_buffer.GetMKLDNNData(); + auto output_mem_t_ = CreateMKLDNNMem(out_data, this->fwd_pd_->dst_desc(), req); + + mkldnn_args_map_t args = { + {MKLDNN_ARG_SRC, *input_mem }, + {MKLDNN_ARG_DST, *(output_mem_t_.second) }, + }; if (this->with_workspace_) { - // mkldnn::memory - auto ws_mem = workspace->GetMKLDNNData(); - this->workspace_->set_data_handle(ws_mem->get_data_handle()); + auto engine = CpuEngine::Get()->get_engine(); + auto ws = std::make_shared((*(this->fwd_pd_)).workspace_desc(), + engine, workspace->GetMKLDNNData()->get_data_handle()); + args[MKLDNN_ARG_WORKSPACE] = *ws; } -} - -void MKLDNNPoolingFwd::Execute(const NDArray& out_data) { if (this->fwd_) { - MKLDNNStream::Get()->RegisterPrim(*(this->fwd_)); - CommitOutput(out_data, this->output_mem_t_); + MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd_), args); + CommitOutput(out_data, output_mem_t_); MKLDNNStream::Get()->Submit(); } else { LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr"; @@ -143,8 +131,8 @@ static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) { } mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( - const PoolingParam ¶m, const bool is_train, const memory::desc &data_md, - const memory::desc &out_md) { + const PoolingParam ¶m, const bool is_train, const mkldnn::memory::desc &data_md, + const mkldnn::memory::desc &out_md) { CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; int kernel_h_, kernel_w_; if (param.global_pool) { @@ -183,19 +171,18 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring; - if (is_train && alg != algorithm::pooling_avg) { + if (is_train && alg != mkldnn::algorithm::pooling_avg) { kind = mkldnn::prop_kind::forward_training; } - const pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, + const mkldnn::pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, {static_cast(stride_h_), static_cast(stride_w_)}, {kernel_h_, kernel_w_}, {static_cast(pad_t_), static_cast(pad_l_)}, {static_cast(pad_b_), - static_cast(pad_r_)}, - padding_kind::zero); + static_cast(pad_r_)}); return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine); } @@ -223,7 +210,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, auto it = pooling_fwds.find(key); if (it == pooling_fwds.end()) { CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; - auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc(); + auto data_md = data.GetMKLDNNData()->get_desc(); int kernel_h_, kernel_w_; if (param.global_pool) { kernel_h_ = data_md.data.dims[2]; @@ -270,42 +257,14 @@ void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, const NDArray &in_data, const OpReqType req, const NDArray &out_data, const NDArray *workspace) { auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); - fwd.SetNewMem(in_data, out_data, req, workspace); - fwd.Execute(out_data); + fwd.Execute(in_data, req, out_data, workspace); } MKLDNNPoolingBwd::MKLDNNPoolingBwd( - const pooling_backward::primitive_desc &pdesc, bool with_ws) - : with_workspace(with_ws), pd(pdesc) {} - -void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace, - const mxnet::NDArray &out_grad, - const mkldnn::memory *diff_src_mem) { - if (bwd == nullptr) { - diff_dst.reset( - new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(), - out_grad.GetMKLDNNData()->get_data_handle())); - diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(), - diff_src_mem->get_data_handle())); - if (with_workspace) { - CHECK(workspace != nullptr); - ws.reset( - new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(), - workspace->GetMKLDNNData()->get_data_handle())); - bwd.reset( - new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src)); - } else { - bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src)); - } - } else { - diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle()); - diff_src->set_data_handle(diff_src_mem->get_data_handle()); - if (with_workspace) { - CHECK(workspace != nullptr); - ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle()); + const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws) + : with_workspace(with_ws), pd(pdesc) { + bwd = std::make_shared(pd); } - } -} const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() { return *this->bwd; @@ -333,27 +292,29 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, auto it = pooling_bwds.find(key); if (it == pooling_bwds.end()) { - auto diff_dst_mem = out_grad.GetMKLDNNData(); + NDArray diff_dst_buff = out_grad; + if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) { + diff_dst_buff = out_grad.Reorder2Default(); + } + auto diff_dst_mem = diff_dst_buff.GetMKLDNNData(); auto input_mem = in_data.GetMKLDNNData(); - mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); - const mkldnn::memory::desc data_md = data_mpd.desc(); - const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], + const mkldnn::memory::desc data_md = input_mem->get_desc(); + const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], static_cast(out_grad.shape()[2]), static_cast(out_grad.shape()[3])}; - const memory::desc out_md( - {dims}, static_cast(data_md.data.data_type), - static_cast(data_md.data.format)); + const mkldnn::memory::desc out_md( + {dims}, static_cast(data_md.data.data_type), + mkldnn::memory::format_tag::any); auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md); - const mkldnn::memory::desc diff_md = - diff_dst_mem->get_primitive_desc().desc(); - const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], + diff_dst_mem->get_desc(); + const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], static_cast(in_grad.shape()[2]), static_cast(in_grad.shape()[3])}; - const memory::desc diff_in_md( - {dims1}, static_cast(diff_md.data.data_type), - static_cast(diff_md.data.format)); - const mkldnn::engine cpu_engine = data_mpd.get_engine(); + const mkldnn::memory::desc diff_in_md( + {dims1}, static_cast(diff_md.data.data_type), + mkldnn::memory::format_tag::any); + const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();; const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); int kernel_h_, kernel_w_; @@ -379,11 +340,10 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, stride_h_ = stride_w_ = 1; } - const pooling_backward::desc desc( + const mkldnn::pooling_backward::desc desc( alg, diff_in_md, diff_md, {stride_h_, stride_w_}, - {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}, - mkldnn::padding_kind::zero); - const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); + {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}); + const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); MKLDNNPoolingBwd bwd(pdesc, with_workspace); it = AddToCache(&pooling_bwds, key, bwd); } @@ -401,14 +361,21 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad); auto diff_src_mem = - CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); + CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req); + + mkldnn_args_map_t args = { + {MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())}, + {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }, + }; + if (MKLDNNRequireWorkspace(param) && workspace != nullptr) { + args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData()); + } - bwd.SetNewMem(workspace, out_grad, diff_src_mem.second); - MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); + MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args); CommitOutput(in_grad, diff_src_mem); MKLDNNStream::Get()->Submit(); } } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100 diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 485fc1345dfd..9f1edc507c22 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -28,7 +28,7 @@ #if MXNET_USE_NNPACK == 1 #include "../nnpack/nnpack_pooling-inl.h" #endif // MXNET_USE_NNPACK -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "./mkldnn/mkldnn_pooling-inl.h" #include "./mkldnn/mkldnn_base-inl.h" #endif // MXNET_USE_MKLDNN @@ -61,7 +61,7 @@ void PoolingParamParser(nnvm::NodeAttrs *attrs) { } int GetNumOutputs(const PoolingParam ¶m) { -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 return MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param) ? 2 : 1; #else return 1; @@ -69,7 +69,7 @@ int GetNumOutputs(const PoolingParam ¶m) { } int GetNumBackInputs(const PoolingParam ¶m) { -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 return MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param) ? 5 : 3; #else return 3; @@ -80,7 +80,7 @@ static bool PoolingType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { out_attrs->at(0) = in_attrs->at(0); -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 const PoolingParam ¶m = nnvm::get(attrs.parsed); if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) { CHECK_GT(out_attrs->size(), 1U); @@ -138,7 +138,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, oshape[i] = 1; out_shape->clear(); out_shape->push_back(oshape); // save output shape -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) out_shape->push_back(oshape); // for workspace #endif @@ -175,7 +175,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, ConvertLayout(oshape_ncw, mshadow::kNCW, mshadow::kNWC) : oshape_ncw; out_shape->clear(); out_shape->push_back(oshape); // save output shape -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) out_shape->push_back(oshape); // for workspace #endif @@ -213,7 +213,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, ConvertLayout(oshape_nchw, mshadow::kNCHW, mshadow::kNHWC) : oshape_nchw; out_shape->clear(); out_shape->push_back(oshape); // save output shape -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) out_shape->push_back(oshape); // for workspace #endif @@ -255,7 +255,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, ConvertLayout(oshape_ncdhw, mshadow::kNCDHW, mshadow::kNDHWC) : oshape_ncdhw; out_shape->clear(); out_shape->push_back(oshape); // save output shape -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) out_shape->push_back(oshape); // for workspace #endif @@ -264,7 +264,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, return true; } -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -420,7 +420,7 @@ For each window ``X``, the mathematical expression for Lp pooling is: const PoolingParam ¶m = nnvm::get(attrs.parsed); return GetNumOutputs(param); }) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("FNumVisibleOutputs", [](const NodeAttrs& attrs) { return 1; }) #endif @@ -437,13 +437,13 @@ For each window ``X``, the mathematical expression for Lp pooling is: return std::vector{"output"}; }) .set_attr_parser(PoolingParamParser) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("FInferStorageType", PoolingStorageType) #endif .set_attr("FInferType", PoolingType) .set_attr("FInferShape", PoolingShape) .set_attr("FCompute", PoolingCompute) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", PoolingComputeExCPU) #endif @@ -460,14 +460,14 @@ NNVM_REGISTER_OP(_backward_Pooling) "FInplaceOption", [](const NodeAttrs &attrs) { // Different backend requires different FInplaceOption -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 const PoolingParam ¶m = nnvm::get(attrs.parsed); if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) return std::vector >{{1, 0}}; #endif return std::vector >(); }) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) @@ -475,7 +475,7 @@ NNVM_REGISTER_OP(_backward_Pooling) BackwardPoolingStorageType) #endif .set_attr_parser(PoolingParamParser) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", PoolingGradComputeExCPU) #endif