From 31c090a8ef50cb0c0297b0268c7bfeb855ede879 Mon Sep 17 00:00:00 2001 From: YixinBao Date: Fri, 10 Apr 2020 14:06:19 +0800 Subject: [PATCH 1/2] [MKLDNN] support using any format in pooling backward (#17900) * use any format in pooling backward * use data_type() --- src/operator/nn/mkldnn/mkldnn_pooling.cc | 62 +++++++++++------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index d2f79700051a..5f75ac49d07c 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -30,6 +30,10 @@ namespace mxnet { namespace op { +static inline mkldnn::memory::data_type get_data_type(const mkldnn::memory::desc &md) { + return static_cast(md.data_type()); +} + void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, @@ -93,7 +97,7 @@ void MKLDNNPoolingFwd::Execute(const NDArray &in_data, auto engine = CpuEngine::Get()->get_engine(); if (workspace == nullptr) { - LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; + LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; } auto ws = std::make_shared((*(this->fwd_pd_)).workspace_desc(), @@ -290,30 +294,21 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, auto it = pooling_bwds.find(key); if (it == pooling_bwds.end()) { - NDArray diff_dst_buff = out_grad; - if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) { - diff_dst_buff = out_grad.Reorder2Default(); - } - auto diff_dst_mem = diff_dst_buff.GetMKLDNNData(); auto input_mem = in_data.GetMKLDNNData(); - const mkldnn::memory::desc data_md = input_mem->get_desc(); - const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], - static_cast(out_grad.shape()[2]), - static_cast(out_grad.shape()[3])}; - const mkldnn::memory::desc out_md( - {dims}, static_cast(data_md.data.data_type), - mkldnn::memory::format_tag::any); - auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md); - const mkldnn::memory::desc diff_md = - diff_dst_mem->get_desc(); - const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], - static_cast(in_grad.shape()[2]), - static_cast(in_grad.shape()[3])}; - const mkldnn::memory::desc diff_in_md( - {dims1}, static_cast(diff_md.data.data_type), - mkldnn::memory::format_tag::any); - const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();; - const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); + auto data_md = input_mem->get_desc(); + + auto dst_dims = mkldnn::memory::dims(out_grad.shape().begin(), out_grad.shape().end()); + auto any = mkldnn::memory::format_tag::any; + auto dst_md = mkldnn::memory::desc(dst_dims, get_data_type(data_md), any); + + // fwd hint + auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, dst_md); + + // creat bwd desc + auto diff_src_dims = mkldnn::memory::dims(in_grad.shape().begin(), in_grad.shape().end()); + auto diff_src_md = mkldnn::memory::desc(diff_src_dims, get_data_type(data_md), any); + auto cpu_engine = CpuEngine::Get()->get_engine();; + auto alg = GetMKLDNNPoolAlgo(param); int kernel_h_, kernel_w_; if (param.global_pool) { @@ -338,10 +333,11 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, stride_h_ = stride_w_ = 1; } - const mkldnn::pooling_backward::desc desc( - alg, diff_in_md, diff_md, {stride_h_, stride_w_}, - {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}); - const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); + // use dst_md as diff_dst_md with any format + auto bwd_desc = mkldnn::pooling_backward::desc(alg, diff_src_md, dst_md, + strides, kernel, pad_l, pad_r); + auto pdesc = mkldnn::pooling_backward::primitive_desc(bwd_desc, cpu_engine, fwd_pd); + MKLDNNPoolingBwd bwd(pdesc, with_workspace); it = AddToCache(&pooling_bwds, key, bwd); } @@ -355,15 +351,15 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, if (req == kNullOp) { return; } + TmpMemMgr::Get()->Init(ctx.requested[0]); auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad); - auto diff_src_mem = - CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req); - + auto diff_dst_mem = out_grad.GetMKLDNNDataReorder(bwd.pd.diff_dst_desc()); + auto diff_src_mem = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req); mkldnn_args_map_t args = { - {MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())}, - {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }, + {MKLDNN_ARG_DIFF_DST, *diff_dst_mem}, + {MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second}, }; if (MKLDNNRequireWorkspace(param) && workspace != nullptr) { args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData()); From 2ee825f68b0d9272b7b58ad8d300d53f9bd59a25 Mon Sep 17 00:00:00 2001 From: Yixin Bao Date: Wed, 15 Apr 2020 15:04:23 +0800 Subject: [PATCH 2/2] fix backport --- src/operator/nn/mkldnn/mkldnn_pooling.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index 5f75ac49d07c..f987054375f1 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -334,8 +334,9 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, } // use dst_md as diff_dst_md with any format - auto bwd_desc = mkldnn::pooling_backward::desc(alg, diff_src_md, dst_md, - strides, kernel, pad_l, pad_r); + auto bwd_desc = mkldnn::pooling_backward::desc( + alg, diff_src_md, dst_md, {stride_h_, stride_w_}, + {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}); auto pdesc = mkldnn::pooling_backward::primitive_desc(bwd_desc, cpu_engine, fwd_pd); MKLDNNPoolingBwd bwd(pdesc, with_workspace);