add mkldnn lrn (#16223)
rongzha1 authored and TaoLv committed Sep 24, 2019
1 parent 0b8805a commit 8ad8b41
Showing 2 changed files with 76 additions and 128 deletions.
12 changes: 6 additions & 6 deletions src/operator/nn/lrn.cc
@@ -26,7 +26,7 @@

#include "./lrn-inl.h"
#include "../operator_common.h"
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "./mkldnn/mkldnn_lrn-inl.h"
#include "./mkldnn/mkldnn_base-inl.h"
#endif
@@ -82,7 +82,7 @@ struct LRNGrad {
}
};

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
bool LRNForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
@@ -169,7 +169,7 @@ number of kernels in the layer.
.set_attr_parser(ParamParser<LRNParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LRNShape)
.set_attr<nnvm::FInferType>("FInferType", LRNType)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FInferStorageType>("FInferStorageType", LRNForwardInferStorageType)
#endif
.set_attr<nnvm::FListInputNames>("FListInputNames",
@@ -181,7 +181,7 @@ number of kernels in the layer.
return std::vector<std::string>{"output", "tmp_norm"};
})
.set_attr<FCompute>("FCompute<cpu>", LRNCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LRNComputeExCPU)
#endif
@@ -192,11 +192,11 @@ number of kernels in the layer.
NNVM_REGISTER_OP(_backward_LRN)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LRNParam>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FInferStorageType>("FInferStorageType", LRNBackwardInferStorageType)
#endif
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", LRNGradComputeExCPU)
// Native compute requires norm while MKLDNN does not so cannot be compared in debug mode
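Every change in lrn.cc is mechanical: the preprocessor guard around the MKL-DNN registration hooks is bumped from MXNET_USE_MKLDNN == 1 to == 100, which, as I read MXNet's convention, marks code that requires the MKL-DNN 1.0 library rather than the legacy 0.x API. A minimal sketch of that guard pattern, with the flag hard-coded purely for illustration (in a real build the value comes from the build system):

#include <cstdio>

// Hypothetical stand-in for the build-system flag: 1 selects the legacy
// MKL-DNN 0.x code path, 100 selects the MKL-DNN 1.0 code path.
#define MXNET_USE_MKLDNN 100

int main() {
#if MXNET_USE_MKLDNN == 100
  std::printf("MKL-DNN 1.0 operator hooks are compiled in\n");
#else
  std::printf("MKL-DNN 1.0 operator hooks are compiled out\n");
#endif
  return 0;
}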
192 changes: 70 additions & 122 deletions src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -25,7 +25,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include <utility>
#include <mkldnn.hpp>
#include "../lrn-inl.h"
@@ -34,27 +34,27 @@
namespace mxnet {
namespace op {

inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
inline mkldnn::algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
// TODO(Patric): lrn_within_channel will cause core dump in MKLDNN backward
// Need to confirm with MKLDNN team and fix later
return algorithm::lrn_across_channels;
return mkldnn::algorithm::lrn_across_channels;
}

inline mkldnn::lrn_forward::primitive_desc GetLRNFwdDesc(
const LRNParam &param, const bool is_train, const memory::desc &src_md) {
const LRNParam &param, const bool is_train, const mkldnn::memory::desc &src_md) {
mkldnn::engine &engine = CpuEngine::Get()->get_engine();
const algorithm alg = GetMKLDNNLRNAlgo(param);
const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param);
const float alpha = param.alpha;
const float beta = param.beta;
const int nsize = param.nsize;
const float k = param.knorm;
auto kind = prop_kind::forward_training;
auto kind = mkldnn::prop_kind::forward_training;
if (is_train) {
kind = prop_kind::forward_training;
kind = mkldnn::prop_kind::forward_training;
} else {
kind = prop_kind::forward_scoring;
kind = mkldnn::prop_kind::forward_scoring;
}
lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
mkldnn::lrn_forward::desc fwd_desc(kind, alg, src_md, nsize, alpha, beta, k);
return mkldnn::lrn_forward::primitive_desc(fwd_desc, engine);
}
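For readers unfamiliar with the new API, here is a self-contained sketch of the MKL-DNN 1.0 calls that GetLRNFwdDesc now relies on. The shapes and LRN hyper-parameters are invented for illustration; the point is that in 1.0 the primitive descriptor is built from a plain memory::desc and queried with dst_desc()/workspace_desc() instead of the 0.x *_primitive_desc() accessors:

#include <mkldnn.hpp>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);

  // Example input layout: 1x8x4x4 float32 activations in NCHW.
  mkldnn::memory::desc src_md({1, 8, 4, 4},
                              mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);

  const bool is_train = true;
  auto kind = is_train ? mkldnn::prop_kind::forward_training
                       : mkldnn::prop_kind::forward_scoring;

  // local_size/alpha/beta/k play the role of the LRNParam fields above.
  mkldnn::lrn_forward::desc fwd_desc(kind,
                                     mkldnn::algorithm::lrn_across_channels,
                                     src_md, /*local_size=*/5,
                                     /*alpha=*/1e-4f, /*beta=*/0.75f, /*k=*/2.f);
  mkldnn::lrn_forward::primitive_desc fwd_pd(fwd_desc, eng);

  // MKL-DNN 1.0: output and workspace layouts are queried via *_desc().
  mkldnn::memory dst(fwd_pd.dst_desc(), eng);
  mkldnn::memory ws(fwd_pd.workspace_desc(), eng);
  return 0;
}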

@@ -63,13 +63,13 @@ inline mkldnn::lrn_backward::primitive_desc GetLRNBwdDesc(
const mkldnn::memory::desc &diff_md,
const mkldnn::lrn_forward::primitive_desc &lrnFwd_desc) {
mkldnn::engine &engine = CpuEngine::Get()->get_engine();
const algorithm alg = GetMKLDNNLRNAlgo(param);
const mkldnn::algorithm alg = GetMKLDNNLRNAlgo(param);
const float alpha = param.alpha;
const float beta = param.beta;
const int nsize = param.nsize;
const float k = param.knorm;

lrn_backward::desc lrnBwd_desc(alg, data_in_md,
mkldnn::lrn_backward::desc lrnBwd_desc(alg, data_in_md,
diff_md, nsize, alpha, beta, k);
return mkldnn::lrn_backward::primitive_desc(lrnBwd_desc,
engine, lrnFwd_desc);
@@ -83,33 +83,24 @@ class MKLDNNLRNFwd {
public:
MKLDNNLRNFwd(const LRNParam& param,
bool is_train,
const NDArray &in_data):
is_train(is_train) {
const NDArray &in_data) {
_Init(param, is_train, in_data);
}

~MKLDNNLRNFwd() {}

void SetNewMem(const NDArray &data,
const NDArray &output,
const OpReqType req);

void SetNewMem(const NDArray &in_data,
const mkldnn::memory *out_mem);

void Execute(const NDArray &out_data);
void Execute(const OpContext &ctx,
const NDArray &in_data,
const OpReqType req,
const NDArray &out_data);

mkldnn::lrn_forward &GetFwd();

const mkldnn::memory *GetWs();
mkldnn::lrn_forward::primitive_desc &GetFwdPd();

private:
std::shared_ptr<mkldnn::lrn_forward> fwd;
std::shared_ptr<mkldnn::memory> in_mem;
std::shared_ptr<mkldnn::memory> out_mem;
std::shared_ptr<mkldnn::memory> ws_mem;
mkldnn_output_t output_mem_t;
bool is_train;
mkldnn::lrn_forward::primitive_desc fwd_pd;

private:
void _Init(const LRNParam &param, bool is_train, const NDArray &in_data);
@@ -119,52 +110,37 @@ void MKLDNNLRNFwd::_Init(const LRNParam &param,
bool is_train,
const NDArray &in_data) {
mkldnn::memory::desc in_data_md =
in_data.GetMKLDNNData()->get_primitive_desc().desc();
mkldnn::lrn_forward::primitive_desc fwd_pd =
in_data.GetMKLDNNData()->get_desc();
this->fwd_pd =
GetLRNFwdDesc(param, is_train, in_data_md);

this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
->get_primitive_desc()));
this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
if (is_train) {
// If it's training, we have to create a workspace memory. Otherwise,
// MKLDNN will have segmentation fault.
ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
*this->ws_mem, *this->out_mem));
} else {
this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*(this->in_mem)),
*(this->out_mem)));
}
}

void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
const NDArray &out_data,
const OpReqType req) {
const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
output_mem_t = CreateMKLDNNMem(out_data, this->out_mem->get_primitive_desc(), req);
this->in_mem->set_data_handle(in_data_mem->get_data_handle());
this->out_mem->set_data_handle(output_mem_t.second->get_data_handle());
this->fwd = std::shared_ptr<mkldnn::lrn_forward>(new mkldnn::lrn_forward(this->fwd_pd));
}

void MKLDNNLRNFwd::SetNewMem(const NDArray &in_data,
const mkldnn::memory *out_mem) {
const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData();
this->in_mem->set_data_handle(in_data_mem->get_data_handle());
this->out_mem->set_data_handle(out_mem->get_data_handle());
}

void MKLDNNLRNFwd::Execute(const NDArray &out_data) {
MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
void MKLDNNLRNFwd::Execute(const OpContext &ctx,
const NDArray &in_data,
const OpReqType req,
const NDArray &out_data) {
auto output_mem_t = CreateMKLDNNMem(out_data, (this->fwd_pd).dst_desc(), req);

mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *in_data.GetMKLDNNData()},
{ MKLDNN_ARG_DST, *output_mem_t.second },
};
std::shared_ptr<mkldnn::memory> workspace;
if (ctx.is_train) {
auto engine = CpuEngine::Get()->get_engine();
workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
args[MKLDNN_ARG_WORKSPACE] = *(workspace);
}
MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd), args);
CommitOutput(out_data, output_mem_t);
MKLDNNStream::Get()->Submit();
}

mkldnn::lrn_forward &MKLDNNLRNFwd::GetFwd() { return *this->fwd; }
mkldnn::lrn_forward::primitive_desc &MKLDNNLRNFwd::GetFwdPd() { return this->fwd_pd; }

const mkldnn::memory *MKLDNNLRNFwd::GetWs() { return this->ws_mem.get(); }
// End of LRN Class and its functions

static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
@@ -180,10 +156,11 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
OpHash> lrn_fwds;
#endif
auto kind_ =
ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
ctx.is_train ? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;

MKLDNNLRNSignature key(param);
key.AddSign(kind_);
key.AddSign(static_cast<int>(kind_));
key.AddSign(in_data);

auto it = lrn_fwds.find(key);
@@ -201,17 +178,12 @@ void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param,
if (in_buffer.IsView() && in_buffer.IsMKLDNNData())
in_buffer = in_buffer.Reorder2Default();
MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
fwd.SetNewMem(in_buffer, out_data, req);
fwd.Execute(out_data);
fwd.Execute(ctx, in_buffer, req, out_data);
}

// LRN Backward Class
class MKLDNNLRNBwd {
std::shared_ptr<mkldnn::lrn_backward> bwd;
std::shared_ptr<mkldnn::memory> in_data_mem;
std::shared_ptr<mkldnn::memory> diff_dst_mem;
std::shared_ptr<mkldnn::memory> ws_mem;
std::shared_ptr<mkldnn::memory> diff_src_mem;

public:
const mkldnn::lrn_forward::primitive_desc fwd_pd;
@@ -222,40 +194,26 @@ class MKLDNNLRNBwd {
MKLDNNLRNBwd(const LRNParam &param, const mkldnn::memory::desc in_data_md,
const mkldnn::memory::desc diff_md)
: fwd_pd(GetLRNFwdDesc(param, true, in_data_md)),
bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {}

void SetNewMem(const NDArray &in_data, const NDArray &out_grad,
const mkldnn::memory *ws, const mkldnn::memory *diff_src_mem) {
if (bwd == nullptr) {
this->in_data_mem.reset(
new mkldnn::memory(this->fwd_pd.src_primitive_desc(),
in_data.GetMKLDNNData()->get_data_handle()));
this->diff_dst_mem.reset(
new mkldnn::memory(this->fwd_pd.dst_primitive_desc(),
out_grad.GetMKLDNNData()->get_data_handle()));
this->ws_mem.reset(
new mkldnn::memory(this->fwd_pd.workspace_primitive_desc(),
ws->get_data_handle()));
this->diff_src_mem.reset(
new mkldnn::memory(this->bwd_pd.diff_src_primitive_desc(),
diff_src_mem->get_data_handle()));
this->bwd.reset(new mkldnn::lrn_backward(
this->bwd_pd, mkldnn::primitive::at(*this->in_data_mem),
mkldnn::primitive::at(*this->diff_dst_mem), *this->ws_mem,
*this->diff_src_mem));
} else {
this->in_data_mem->set_data_handle(
in_data.GetMKLDNNData()->get_data_handle());
this->diff_dst_mem->set_data_handle(
out_grad.GetMKLDNNData()->get_data_handle());
this->ws_mem->set_data_handle(ws->get_data_handle());
this->diff_src_mem->set_data_handle(diff_src_mem->get_data_handle());
}
}

void Execute(const NDArray &in_grad, const mkldnn_output_t &diff_src_mem_) {
MKLDNNStream::Get()->RegisterPrim(*(this->bwd));
CommitOutput(in_grad, diff_src_mem_);
bwd_pd(GetLRNBwdDesc(param, in_data_md, diff_md, this->fwd_pd)) {
bwd = std::make_shared<mkldnn::lrn_backward>(bwd_pd);
}

const mkldnn::lrn_backward &GetBwd() const { return *bwd; }

void Execute(const NDArray &out_grad,
const NDArray &in_data,
const NDArray &in_grad,
const mkldnn_output_t &diff_src_mem) {
auto engine = CpuEngine::Get()->get_engine();
auto workspace = std::make_shared<mkldnn::memory>((this->fwd_pd).workspace_desc(), engine);
mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *in_data.GetMKLDNNData() },
{ MKLDNN_ARG_DIFF_DST, *out_grad.GetMKLDNNData()},
{ MKLDNN_ARG_WORKSPACE, *workspace },
{ MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }
};
MKLDNNStream::Get()->RegisterPrimArgs(*(this->bwd), args);
CommitOutput(in_grad, diff_src_mem);
MKLDNNStream::Get()->Submit();
}
}; // End of LRN Class
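The new Execute() methods above show the other big 1.0 change: a primitive is now constructed from its primitive_desc alone, and the memories are supplied at execution time through an argument map (MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST, ...), instead of being bound into the primitive at construction time as in 0.x. A standalone sketch of that pattern for an LRN forward plus backward pass, again with invented shapes and hyper-parameters:

#include <mkldnn.hpp>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  mkldnn::memory::desc data_md({1, 8, 4, 4}, mkldnn::memory::data_type::f32,
                               mkldnn::memory::format_tag::nchw);

  mkldnn::lrn_forward::desc fwd_d(mkldnn::prop_kind::forward_training,
                                  mkldnn::algorithm::lrn_across_channels,
                                  data_md, 5, 1e-4f, 0.75f, 2.f);
  mkldnn::lrn_forward::primitive_desc fwd_pd(fwd_d, eng);

  mkldnn::lrn_backward::desc bwd_d(mkldnn::algorithm::lrn_across_channels,
                                   data_md, data_md, 5, 1e-4f, 0.75f, 2.f);
  mkldnn::lrn_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);

  mkldnn::memory src(data_md, eng), dst(fwd_pd.dst_desc(), eng);
  mkldnn::memory ws(fwd_pd.workspace_desc(), eng);
  mkldnn::memory diff_dst(data_md, eng), diff_src(bwd_pd.diff_src_desc(), eng);

  // Fill the inputs with something deterministic.
  auto fill = [](mkldnn::memory &m, float v) {
    float *p = static_cast<float *>(m.get_data_handle());
    for (size_t i = 0; i < m.get_desc().get_size() / sizeof(float); ++i) p[i] = v;
  };
  fill(src, 1.f);
  fill(diff_dst, 0.5f);

  // Forward produces the training workspace that backward consumes.
  mkldnn::lrn_forward(fwd_pd).execute(strm, {{MKLDNN_ARG_SRC, src},
                                             {MKLDNN_ARG_DST, dst},
                                             {MKLDNN_ARG_WORKSPACE, ws}});
  mkldnn::lrn_backward(bwd_pd).execute(strm, {{MKLDNN_ARG_SRC, src},
                                              {MKLDNN_ARG_DIFF_DST, diff_dst},
                                              {MKLDNN_ARG_WORKSPACE, ws},
                                              {MKLDNN_ARG_DIFF_SRC, diff_src}});
  strm.wait();
  return 0;
}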
@@ -277,9 +235,9 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
auto it = lrn_bwds.find(key);
if (it == lrn_bwds.end()) {
const mkldnn::memory::desc in_data_md =
in_data.GetMKLDNNData()->get_primitive_desc().desc();
in_data.GetMKLDNNData()->get_desc();
const mkldnn::memory::desc diff_md =
out_grad.GetMKLDNNData()->get_primitive_desc().desc();
out_grad.GetMKLDNNData()->get_desc();
MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
it = AddToCache(&lrn_bwds, key, bwd);
}
@@ -300,23 +258,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
in_buffer = in_data.Reorder2Default();
}
MKLDNNLRNBwd &bwd = GetLRNBwd(param, in_buffer, in_grad, out_grad);
// Repeat FW for getting workspace
// TODO(Patric): To keep the function stateless, we can't pass workspace
// from LRN forward to backward. We have to re-compute
// LRN forward to get the workspace.
// Will refine this code later.
MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_buffer);
std::shared_ptr<const mkldnn::memory> dst_temp(
new mkldnn::memory(bwd.fwd_pd.dst_primitive_desc()));
fwd.SetNewMem(in_buffer, dst_temp.get());
MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());

mkldnn_output_t diff_src_mem =
CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_primitive_desc(), req);
bwd.SetNewMem(in_buffer, out_grad, fwd.GetWs(), diff_src_mem.second);
bwd.Execute(in_grad, diff_src_mem);
CreateMKLDNNMem(in_grad, bwd.bwd_pd.diff_src_desc(), req);

bwd.Execute(out_grad, in_buffer, in_grad, diff_src_mem);
}
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H__
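One detail the truncated hunks above do not show in full: GetLRNFwd and GetLRNBwd cache the constructed objects in a local map keyed by a signature built from the LRN parameters, the forward kind and the input array, so repeated calls with the same configuration reuse the same primitive. A generic, MXNet-free sketch of that caching idea; every name here is illustrative only, the real code uses MKLDNNLRNSignature, AddToCache and caches the whole MKLDNNLRNFwd/MKLDNNLRNBwd object:

#include <mkldnn.hpp>
#include <cstdint>
#include <functional>
#include <unordered_map>

// Illustrative cache key: just enough state to tell LRN configurations apart.
struct LRNKey {
  int nsize;
  float alpha, beta, k;
  bool training;
  mkldnn::memory::dims shape;
  bool operator==(const LRNKey &o) const {
    return nsize == o.nsize && alpha == o.alpha && beta == o.beta &&
           k == o.k && training == o.training && shape == o.shape;
  }
};

struct LRNKeyHash {
  size_t operator()(const LRNKey &key) const {
    size_t h = std::hash<int>()(key.nsize) * 31 + std::hash<bool>()(key.training);
    for (auto d : key.shape) h = h * 31 + std::hash<int64_t>()(d);
    return h;
  }
};

// Build (or reuse) the forward primitive_desc for a given configuration.
const mkldnn::lrn_forward::primitive_desc &GetCachedLRNFwdPd(
    const LRNKey &key, const mkldnn::memory::desc &src_md,
    const mkldnn::engine &eng) {
  static thread_local std::unordered_map<LRNKey,
      mkldnn::lrn_forward::primitive_desc, LRNKeyHash> cache;
  auto it = cache.find(key);
  if (it == cache.end()) {
    auto kind = key.training ? mkldnn::prop_kind::forward_training
                             : mkldnn::prop_kind::forward_scoring;
    mkldnn::lrn_forward::desc d(kind, mkldnn::algorithm::lrn_across_channels,
                                src_md, key.nsize, key.alpha, key.beta, key.k);
    it = cache.emplace(key, mkldnn::lrn_forward::primitive_desc(d, eng)).first;
  }
  return it->second;
}

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::memory::desc src_md({1, 8, 4, 4}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);
  LRNKey key{5, 1e-4f, 0.75f, 2.f, true, {1, 8, 4, 4}};
  const auto &pd1 = GetCachedLRNFwdPd(key, src_md, eng);
  const auto &pd2 = GetCachedLRNFwdPd(key, src_md, eng);  // cache hit
  return &pd1 == &pd2 ? 0 : 1;
}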
