This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Add MKL-DNN Pooling #16272

Merged (8 commits, Oct 10, 2019)
Changes from 6 commits
40 changes: 14 additions & 26 deletions src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -24,7 +24,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100

#include <utility>
#include <mkldnn.hpp>
@@ -43,60 +43,48 @@ class MKLDNNPoolingFwd {
const int padding_t, const int padding_b,
const int padding_l, const int padding_r,
const mkldnn::algorithm alg_kind,
const bool with_workspace, const bool is_train) :
is_train_(is_train),
const bool with_workspace, const bool is_train):
with_workspace_(with_workspace),
alg_kind_(alg_kind),
fwd_(nullptr), data_(nullptr), out_(nullptr), workspace_(nullptr) {
fwd_(nullptr) {
Init(input, output,
kernel_h, kernel_w, stride_h, stride_w,
padding_t, padding_b, padding_l, padding_r);
padding_t, padding_b, padding_l, padding_r,
is_train, alg_kind);
}

~MKLDNNPoolingFwd() {}
void SetNewMem(const NDArray& in_data,
const NDArray& out_data,
const OpReqType& req,
const mxnet::NDArray *workspace = nullptr);
void Execute(const NDArray& out_data);
void Execute(const NDArray &in_data,
const OpReqType req,
const NDArray& out_data,
const NDArray *workspace);

private:
bool is_train_;
bool with_workspace_;
mkldnn::algorithm alg_kind_;

std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::pooling_forward> fwd_;
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
std::shared_ptr<mkldnn::memory> workspace_;
mkldnn_output_t output_mem_t_;

private:
void Init(const mxnet::NDArray &input,
const mxnet::NDArray &output,
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int padding_t, const int padding_b,
const int padding_l, const int padding_r);
const int padding_l, const int padding_r,
const bool is_train, const mkldnn::algorithm alg_kind);
};

class MKLDNNPoolingBwd {
std::shared_ptr<const mkldnn::pooling_backward> bwd;
std::shared_ptr<mkldnn::memory> diff_dst;
std::shared_ptr<mkldnn::memory> diff_src;
std::shared_ptr<mkldnn::memory> ws;
bool with_workspace;

public:
const mkldnn::pooling_backward::primitive_desc pd;

MKLDNNPoolingBwd(const pooling_backward::primitive_desc &pdesc,
MKLDNNPoolingBwd(const mkldnn::pooling_backward::primitive_desc &pdesc,
bool with_ws);

~MKLDNNPoolingBwd() {}
void SetNewMem(const mxnet::NDArray *workspace,
const mxnet::NDArray &out_grad,
const mkldnn::memory *diff_src_mem);
const mkldnn::pooling_backward &GetBwd();
const mkldnn::pooling_backward::primitive_desc &GetPd();
};
@@ -141,5 +129,5 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
const NDArray &output);
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_POOLING_INL_H_
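Taken together, the header changes replace the stateful SetNewMem/Execute pair with a single Execute call that receives all arrays per invocation, so no memory handles live in the class between calls. A minimal sketch of the resulting call pattern, mirroring MKLDNNPoolingCompute further down (illustrative, not itself part of the patch):

  // Caller-side view of the new interface; param, ctx, in_data, req,
  // out_data and workspace are the usual MXNet operator inputs.
  MKLDNNPoolingFwd &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
  fwd.Execute(in_data, req, out_data, workspace);  // memory is bound per call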
177 changes: 73 additions & 104 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -23,7 +23,7 @@
* \author Tao Lv
*/

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100

#include "./mkldnn_pooling-inl.h"

@@ -34,18 +34,17 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
const int padding_t, const int padding_b,
const int padding_l, const int padding_r) {
// mkldnn::memory::desc
auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc();
const int padding_l, const int padding_r,
const bool is_train, const mkldnn::algorithm alg_kind) {
auto src_md = input.GetMKLDNNData()->get_desc();
mkldnn::memory::dims dims = {src_md.data.dims[0],
src_md.data.dims[1],
static_cast<int>(output.shape()[2]),
static_cast<int>(output.shape()[3])};
auto dst_md = mkldnn::memory::desc({dims},
static_cast<mkldnn::memory::data_type>(src_md.data.data_type),
static_cast<mkldnn::memory::format>(src_md.data.format));
mkldnn::memory::format_tag::any);
const mkldnn::engine engine = CpuEngine::Get()->get_engine();
const mkldnn::algorithm alg_kind = this->alg_kind_;
if (alg_kind != mkldnn::algorithm::pooling_max &&
alg_kind != mkldnn::algorithm::pooling_avg &&
alg_kind != mkldnn::algorithm::pooling_avg_include_padding &&
@@ -54,10 +53,10 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
}

mkldnn::prop_kind prop = mkldnn::prop_kind::forward_scoring;
if (this->is_train_ && alg_kind != mkldnn::algorithm::pooling_avg) {
if (is_train && alg_kind != mkldnn::algorithm::pooling_avg) {
prop = mkldnn::prop_kind::forward_training;
}
if (this->is_train_ && prop == mkldnn::prop_kind::forward_scoring) {
if (is_train && prop == mkldnn::prop_kind::forward_scoring) {
LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring";
}

@@ -67,49 +66,38 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
const mkldnn::memory::dims kernel = {kernel_h, kernel_w };
// mkldnn::pooling_forward::desc
const auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md,
strides, kernel, pad_l, pad_r,
mkldnn::padding_kind::zero);
strides, kernel, pad_l, pad_r);
this->fwd_pd_.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine));
this->data_.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc()));
this->out_.reset(new mkldnn::memory(this->fwd_pd_->dst_primitive_desc()));
if (this->with_workspace_) {
this->workspace_.reset(new mkldnn::memory(this->fwd_pd_->workspace_primitive_desc()));
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_),
mkldnn::primitive::at(*(this->data_)),
*(this->out_),
*(this->workspace_)));
} else {
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_),
mkldnn::primitive::at(*(this->data_)),
*(this->out_)));
}
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_)));

return;
}
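Note the dst_md change in this hunk: the v0.x code pinned the destination to the source's concrete layout, whereas the v1.0 code passes mkldnn::memory::format_tag::any so that primitive-descriptor creation can choose the optimal layout. Schematically (variable names as in the hunk above):

  // v0.x: static_cast<mkldnn::memory::format>(src_md.data.format) pinned the layout.
  // v1.0: defer the layout decision to the library.
  auto dst_md = mkldnn::memory::desc({dims},
      static_cast<mkldnn::memory::data_type>(src_md.data.data_type),
      mkldnn::memory::format_tag::any);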

void MKLDNNPoolingFwd::SetNewMem(const NDArray& in_data,
const NDArray& out_data,
const OpReqType& req,
const mxnet::NDArray *workspace) {
auto input_mem = in_data.GetMKLDNNData();
output_mem_t_ = CreateMKLDNNMem(out_data, fwd_pd_->dst_primitive_desc(), req);
// mkldnn::memory
this->data_->set_data_handle(input_mem->get_data_handle());
this->out_->set_data_handle(output_mem_t_.second->get_data_handle());
if (this->with_workspace_ && workspace == nullptr) {
LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
}
void MKLDNNPoolingFwd::Execute(const NDArray &in_data,
const OpReqType req,
const NDArray& out_data,
const NDArray *workspace) {
NDArray in_buffer = in_data;
if (in_data.IsView() && in_data.IsMKLDNNData())
in_buffer = in_data.Reorder2Default();

auto input_mem = in_buffer.GetMKLDNNData();
auto output_mem_t_ = CreateMKLDNNMem(out_data, this->fwd_pd_->dst_desc(), req);

mkldnn_args_map_t args = {
{MKLDNN_ARG_SRC, *input_mem },
{MKLDNN_ARG_DST, *(output_mem_t_.second) },
};

if (this->with_workspace_) {
// mkldnn::memory
auto ws_mem = workspace->GetMKLDNNData();
this->workspace_->set_data_handle(ws_mem->get_data_handle());
auto engine = CpuEngine::Get()->get_engine();
auto ws = std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(),
engine, workspace->GetMKLDNNData()->get_data_handle());
args[MKLDNN_ARG_WORKSPACE] = *ws;
}
}

void MKLDNNPoolingFwd::Execute(const NDArray& out_data) {
if (this->fwd_) {
MKLDNNStream::Get()->RegisterPrim(*(this->fwd_));
CommitOutput(out_data, this->output_mem_t_);
MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd_), args);
CommitOutput(out_data, output_mem_t_);
MKLDNNStream::Get()->Submit();
} else {
LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr";
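The rewritten Execute reflects the central mkldnn v1.0 execution-model change: a primitive is constructed from its primitive_desc alone, and memory objects are supplied at execution time through an argument map. In isolation, assuming <mkldnn.hpp> and pre-created fwd_pd, src_mem, dst_mem, ws_mem and stream objects (names illustrative; MXNet wraps the stream in MKLDNNStream):

  mkldnn::pooling_forward fwd(fwd_pd);           // v1.0: no memory at construction
  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_SRC, src_mem},
      {MKLDNN_ARG_DST, dst_mem},
  };
  if (with_workspace)                            // max pooling in training mode
    args.insert({MKLDNN_ARG_WORKSPACE, ws_mem});
  fwd.execute(stream, args);                     // memory is bound per call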
@@ -143,8 +131,8 @@ static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) {
}

mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
const PoolingParam &param, const bool is_train, const memory::desc &data_md,
const memory::desc &out_md) {
const PoolingParam &param, const bool is_train, const mkldnn::memory::desc &data_md,
const mkldnn::memory::desc &out_md) {
CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
int kernel_h_, kernel_w_;
if (param.global_pool) {
@@ -183,19 +171,18 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(

const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
if (is_train && alg != algorithm::pooling_avg) {
if (is_train && alg != mkldnn::algorithm::pooling_avg) {
kind = mkldnn::prop_kind::forward_training;
}

const pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md,
const mkldnn::pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md,
{static_cast<int>(stride_h_),
static_cast<int>(stride_w_)},
{kernel_h_, kernel_w_},
{static_cast<int>(pad_t_),
static_cast<int>(pad_l_)},
{static_cast<int>(pad_b_),
static_cast<int>(pad_r_)},
padding_kind::zero);
static_cast<int>(pad_r_)});
return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine);
}
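GetPoolingFwdPdesc also shows the other recurring v1.0 API change in this patch: pooling_forward::desc no longer takes a padding_kind argument, because zero padding is the only, and therefore implicit, behavior. Before and after, schematically:

  // v0.x:
  //   mkldnn::pooling_forward::desc(kind, alg, data_md, out_md,
  //                                 strides, kernel, pad_l, pad_r,
  //                                 mkldnn::padding_kind::zero);
  // v1.0 (this patch): the trailing padding_kind argument is dropped.
  mkldnn::pooling_forward::desc fwd_desc(kind, alg, data_md, out_md,
                                         strides, kernel, pad_l, pad_r);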

@@ -223,7 +210,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
auto it = pooling_fwds.find(key);
if (it == pooling_fwds.end()) {
CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc();
auto data_md = data.GetMKLDNNData()->get_desc();
int kernel_h_, kernel_w_;
if (param.global_pool) {
kernel_h_ = data_md.data.dims[2];
@@ -270,42 +257,14 @@ void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
const NDArray &in_data, const OpReqType req,
const NDArray &out_data, const NDArray *workspace) {
auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
fwd.SetNewMem(in_data, out_data, req, workspace);
fwd.Execute(out_data);
fwd.Execute(in_data, req, out_data, workspace);
}

MKLDNNPoolingBwd::MKLDNNPoolingBwd(
const pooling_backward::primitive_desc &pdesc, bool with_ws)
: with_workspace(with_ws), pd(pdesc) {}

void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace,
const mxnet::NDArray &out_grad,
const mkldnn::memory *diff_src_mem) {
if (bwd == nullptr) {
diff_dst.reset(
new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(),
out_grad.GetMKLDNNData()->get_data_handle()));
diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(),
diff_src_mem->get_data_handle()));
if (with_workspace) {
CHECK(workspace != nullptr);
ws.reset(
new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(),
workspace->GetMKLDNNData()->get_data_handle()));
bwd.reset(
new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src));
} else {
bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src));
}
} else {
diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle());
diff_src->set_data_handle(diff_src_mem->get_data_handle());
if (with_workspace) {
CHECK(workspace != nullptr);
ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle());
const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws)
: with_workspace(with_ws), pd(pdesc) {
bwd = std::make_shared<mkldnn::pooling_backward>(pd);
}
}
}
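The constructor change is possible for the same reason: in v0.x the backward primitive consumed memory objects at construction, which forced the lazy SetNewMem dance; in v1.0 only the primitive_desc is needed, so the primitive is built once, eagerly, and the cached object stays free of per-call state. Side by side:

  // v0.x (removed above): lazy, memory-coupled construction
  //   bwd.reset(new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src));
  // v1.0 (this patch): eager, memory-free construction
  bwd = std::make_shared<mkldnn::pooling_backward>(pd);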

const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() {
return *this->bwd;
@@ -333,27 +292,31 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,

auto it = pooling_bwds.find(key);
if (it == pooling_bwds.end()) {
auto diff_dst_mem = out_grad.GetMKLDNNData();
// mkldnn v1.0: add reorder to work around testcase test_make_subgraph;
// already fixed in v1.1, will remove after v1.1 is integrated.
Member: Please remove the comments.
Contributor Author: done
NDArray diff_dst_buff = out_grad;
if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) {
diff_dst_buff = out_grad.Reorder2Default();
}
auto diff_dst_mem = diff_dst_buff.GetMKLDNNData();
auto input_mem = in_data.GetMKLDNNData();
mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
const mkldnn::memory::desc data_md = data_mpd.desc();
const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
const mkldnn::memory::desc data_md = input_mem->get_desc();
const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
static_cast<int>(out_grad.shape()[2]),
static_cast<int>(out_grad.shape()[3])};
const memory::desc out_md(
{dims}, static_cast<memory::data_type>(data_md.data.data_type),
static_cast<memory::format>(data_md.data.format));
const mkldnn::memory::desc out_md(
{dims}, static_cast<mkldnn::memory::data_type>(data_md.data.data_type),
mkldnn::memory::format_tag::any);
auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);

const mkldnn::memory::desc diff_md =
diff_dst_mem->get_primitive_desc().desc();
const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
diff_dst_mem->get_desc();
const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
static_cast<int>(in_grad.shape()[2]),
static_cast<int>(in_grad.shape()[3])};
const memory::desc diff_in_md(
{dims1}, static_cast<memory::data_type>(diff_md.data.data_type),
static_cast<memory::format>(diff_md.data.format));
const mkldnn::engine cpu_engine = data_mpd.get_engine();
const mkldnn::memory::desc diff_in_md(
{dims1}, static_cast<mkldnn::memory::data_type>(diff_md.data.data_type),
mkldnn::memory::format_tag::any);
const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();
const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);

int kernel_h_, kernel_w_;
@@ -379,11 +342,10 @@
stride_h_ = stride_w_ = 1;
}

const pooling_backward::desc desc(
const mkldnn::pooling_backward::desc desc(
alg, diff_in_md, diff_md, {stride_h_, stride_w_},
{kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_},
mkldnn::padding_kind::zero);
const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
{kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
MKLDNNPoolingBwd bwd(pdesc, with_workspace);
it = AddToCache(&pooling_bwds, key, bwd);
}
@@ -401,14 +363,21 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,

auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
auto diff_src_mem =
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);

mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())},
{MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second },
};
if (workspace != nullptr) {
Member: Also check with_workspace?
Contributor Author: OK. Done
args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
}

bwd.SetNewMem(workspace, out_grad, diff_src_mem.second);
MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(in_grad, diff_src_mem);
MKLDNNStream::Get()->Submit();
}
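Per the review exchange above, the guard should also consult the cached primitive's with_workspace flag rather than only the workspace pointer. A hedged sketch of what "OK. Done" plausibly resolved to; with_workspace is private in the header shown here, so the sketch assumes it is exposed, and the final commit may differ:

  // Assumption: MKLDNNPoolingBwd exposes its with_workspace flag (not shown
  // in this diff); the pointer check alone would silently skip the workspace.
  if (bwd.with_workspace) {
    CHECK(workspace != nullptr) << "MKLDNN Pooling: workspace is required";
    args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
  }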

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 0
Member: Why?
Contributor Author: typo