Skip to content

Commit

Permalink
[v1.x] backport apache#17900 "[MKLDNN] support using any format in po…
Browse files Browse the repository at this point in the history
…oling backward" (apache#18067)

* [MKLDNN] support using any format in pooling backward (apache#17900)

* use any format in pooling backward

* use data_type()

* fix backport
  • Loading branch information
ElaineBao authored Apr 16, 2020
1 parent 1afdfce commit b56571d
Showing 1 changed file with 30 additions and 33 deletions.
63 changes: 30 additions & 33 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
namespace mxnet {
namespace op {

static inline mkldnn::memory::data_type get_data_type(const mkldnn::memory::desc &md) {
return static_cast<mkldnn::memory::data_type>(md.data_type());
}

void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
Expand Down Expand Up @@ -93,7 +97,7 @@ void MKLDNNPoolingFwd::Execute(const NDArray &in_data,
auto engine = CpuEngine::Get()->get_engine();

if (workspace == nullptr) {
LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
}

auto ws = std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(),
Expand Down Expand Up @@ -290,30 +294,21 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,

auto it = pooling_bwds.find(key);
if (it == pooling_bwds.end()) {
NDArray diff_dst_buff = out_grad;
if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) {
diff_dst_buff = out_grad.Reorder2Default();
}
auto diff_dst_mem = diff_dst_buff.GetMKLDNNData();
auto input_mem = in_data.GetMKLDNNData();
const mkldnn::memory::desc data_md = input_mem->get_desc();
const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
static_cast<int>(out_grad.shape()[2]),
static_cast<int>(out_grad.shape()[3])};
const mkldnn::memory::desc out_md(
{dims}, static_cast<mkldnn::memory::data_type>(data_md.data.data_type),
mkldnn::memory::format_tag::any);
auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
const mkldnn::memory::desc diff_md =
diff_dst_mem->get_desc();
const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
static_cast<int>(in_grad.shape()[2]),
static_cast<int>(in_grad.shape()[3])};
const mkldnn::memory::desc diff_in_md(
{dims1}, static_cast<mkldnn::memory::data_type>(diff_md.data.data_type),
mkldnn::memory::format_tag::any);
const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();;
const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
auto data_md = input_mem->get_desc();

auto dst_dims = mkldnn::memory::dims(out_grad.shape().begin(), out_grad.shape().end());
auto any = mkldnn::memory::format_tag::any;
auto dst_md = mkldnn::memory::desc(dst_dims, get_data_type(data_md), any);

// fwd hint
auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, dst_md);

// creat bwd desc
auto diff_src_dims = mkldnn::memory::dims(in_grad.shape().begin(), in_grad.shape().end());
auto diff_src_md = mkldnn::memory::desc(diff_src_dims, get_data_type(data_md), any);
auto cpu_engine = CpuEngine::Get()->get_engine();;
auto alg = GetMKLDNNPoolAlgo(param);

int kernel_h_, kernel_w_;
if (param.global_pool) {
Expand All @@ -338,10 +333,12 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
stride_h_ = stride_w_ = 1;
}

const mkldnn::pooling_backward::desc desc(
alg, diff_in_md, diff_md, {stride_h_, stride_w_},
{kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
// use dst_md as diff_dst_md with any format
auto bwd_desc = mkldnn::pooling_backward::desc(
alg, diff_src_md, dst_md, {stride_h_, stride_w_},
{kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
auto pdesc = mkldnn::pooling_backward::primitive_desc(bwd_desc, cpu_engine, fwd_pd);

MKLDNNPoolingBwd bwd(pdesc, with_workspace);
it = AddToCache(&pooling_bwds, key, bwd);
}
Expand All @@ -355,15 +352,15 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
if (req == kNullOp) {
return;
}

TmpMemMgr::Get()->Init(ctx.requested[0]);

auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
auto diff_src_mem =
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);

auto diff_dst_mem = out_grad.GetMKLDNNDataReorder(bwd.pd.diff_dst_desc());
auto diff_src_mem = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())},
{MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second },
{MKLDNN_ARG_DIFF_DST, *diff_dst_mem},
{MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second},
};
if (MKLDNNRequireWorkspace(param) && workspace != nullptr) {
args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
Expand Down

0 comments on commit b56571d

Please sign in to comment.