MKLDNN Backward op cache (#11301)
* Enable primitive allocation cache for _backward_Activation.

Change-Id: I545628ff68a54cb01b7fef323dc3de9bd47b1a19

* Enable primitive allocation cache for _backward_Deconvolution.

Change-Id: I1e9bf1b9b44bae52068a9c564dff037851e896e5

* Enable primitive allocation cache for _backward_Pooling.

Change-Id: Idbe94e21f1e2ddf711523767194b95beda19b120

* Enable primitive allocation cache for _backward_LRN.

Change-Id: Iefe9f720de719ec2e2f5d24a006602425136711b

* Enable primitive allocation cache for _backward_BatchNorm.

Change-Id: I9e52651bd830b8cb5d2f193076ef51606c9056f9

* Enable primitive allocation cache for _backward_Convolution

Change-Id: I0496fa2394ee036d05c58f3abc1d74af544c7bca

* Enable primitive allocation cache for _backward_Fully_Connected

Change-Id: I8347527ec1271b1518921a74e3581d7d84187429

* remove fc forward and fix indent problem

* remove fc forward and fix convolution indent problem

* Change log level to FATAL for unreachable code in mkldnn_act.cc

* remove fc forward and fix convolution indent problem

* remove useless hint in fc

* Merge branch 'master' into backward_op_cache

* Empty commit to retrigger the CI.

* Change LOG(INFO) to LOG(FATAL) for unreachable code in mkldnn_act.cc

* Fix build issue after code merge.

* Fix lint after merge

* Fix mkldnn act.
ZhennanQin authored and eric-haibin-lin committed Sep 13, 2018
1 parent 7735fa6 commit 741635a
Showing 11 changed files with 844 additions and 248 deletions.
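
All of the commits above apply the same change to a different backward operator: the MKLDNN primitive descriptor and primitive are built once, stored in a thread-local map keyed by the operator parameters plus the input/output tensor signatures, and later calls only rebind the data handles instead of recreating the primitive. The sketch below illustrates that pattern in isolation. It is not code from this commit; Signature, SignatureHash, CachedBackward, and GetCachedBackward are hypothetical stand-ins for the real MKLDNNActSignature / OpHash / MKLDNNActBackward / GetActBackward (and their BatchNorm counterparts) shown in the diffs that follow.

// backward_op_cache_sketch.cc -- conceptual sketch only, not part of this commit.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Stand-in for MKLDNNActSignature / MKLDNNBNSignature: collects the parameters
// and tensor shapes/dtypes that determine how the primitive must be built.
struct Signature {
  std::vector<int64_t> values;
  void AddSign(const std::vector<int64_t> &shape, int dtype) {
    values.insert(values.end(), shape.begin(), shape.end());
    values.push_back(dtype);
  }
  bool operator==(const Signature &other) const { return values == other.values; }
};

struct SignatureHash {  // stand-in for OpHash
  size_t operator()(const Signature &s) const {
    size_t h = 0;
    for (int64_t v : s.values) h = h * 31 + std::hash<int64_t>()(v);
    return h;
  }
};

// Stand-in for a cached backward primitive: expensive to construct once,
// cheap to rerun after its data handles are updated.
class CachedBackward {
 public:
  explicit CachedBackward(const Signature & /*sig*/) {
    // ... expensive primitive-descriptor / primitive construction happens once here ...
  }
  // Mirrors MKLDNNActBackward::SetNewMem / MKLDNNBNBackward::SetDataHandle:
  // between iterations only the data pointers change, not the primitive.
  void SetNewMem(void *src, void *diff_dst, void *diff_src) {
    src_ = src; diff_dst_ = diff_dst; diff_src_ = diff_src;
  }
  void Execute() { /* run the primitive on the currently bound handles */ }
 private:
  void *src_ = nullptr;
  void *diff_dst_ = nullptr;
  void *diff_src_ = nullptr;
};

// Mirrors GetActBackward / GetBNBackward: one cache per thread, keyed by signature.
static CachedBackward &GetCachedBackward(const Signature &key) {
  static thread_local std::unordered_map<Signature, CachedBackward, SignatureHash> bwds;
  auto it = bwds.find(key);
  if (it == bwds.end())
    it = bwds.emplace(key, CachedBackward(key)).first;
  return it->second;
}

int main() {
  Signature key;
  key.AddSign({32, 3, 224, 224}, /*dtype=*/0);
  // The first call builds the primitive; later calls with the same signature
  // only rebind data handles and re-run it.
  for (int iter = 0; iter < 3; ++iter) {
    CachedBackward &bwd = GetCachedBackward(key);
    bwd.SetNewMem(/*src=*/nullptr, /*diff_dst=*/nullptr, /*diff_src=*/nullptr);
    bwd.Execute();
  }
  return 0;
}

The per-thread map in the sketch mirrors the DMLC_CXX11_THREAD_LOCAL / MX_THREAD_LOCAL branches in the diffs, which keep the cache lock-free by giving each worker thread its own copy.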
3 changes: 1 addition & 2 deletions src/operator/nn/fully_connected.cc
@@ -213,10 +213,9 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
uint32_t out_expected = param.no_bias ? 2 : 3;
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), out_expected);

bool dispatched = false;
// TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
// It seems there is a bug.
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
2 changes: 2 additions & 0 deletions src/operator/nn/lrn.cc
@@ -116,6 +116,8 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs,
MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);
MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
// Copy outputs[1] from opcheck reference as backward check needs it.
MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
return;
}
FallBackCompute(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
125 changes: 101 additions & 24 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -84,7 +84,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
alg, data_md, alpha);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
});
LOG(INFO) << "Unsupported data type for MKLDNN activation";
LOG(FATAL) << "Unsupported data type for MKLDNN activation";
mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, alg, data_md, 0.0);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
@@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
stream->Submit();
}

static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
const ActivationParam &param, const mkldnn::memory &input_mem,
const mkldnn::memory &diff_dst_memory, int dtype) {
mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
auto cpu_engine = data_mpd.get_engine();
auto alg = GetMKLDNNActAlgo(param);

MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType alpha = 0;
mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
alg, data_md, alpha);
mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
fw_pdesc);
return bw_pdesc;
});
LOG(FATAL) << "Unsupported data type for MKLDNN activation";
mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
alg, data_md, 0.0);
mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
fw_pdesc);
return bw_pdesc;
}

class MKLDNNActBackward {
std::shared_ptr<mkldnn::eltwise_backward> bwd;
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> diff_dst_memory;
std::shared_ptr<mkldnn::memory> diff_src_memory;

public:
const mkldnn::eltwise_backward::primitive_desc pd;

explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
const mkldnn::memory &mem,
const mkldnn::memory &diff_dst_memory)
: pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}

void SetNewMem(const mkldnn::memory &data,
const mkldnn::memory &diff_dst_memory,
const mkldnn::memory &diff_src_memory) {
if (this->bwd != nullptr) {
this->data->set_data_handle(data.get_data_handle());
this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
} else {
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
data.get_primitive_desc(), data.get_data_handle()));
this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
diff_dst_memory.get_data_handle()));
this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(diff_src_memory.get_primitive_desc(),
diff_src_memory.get_data_handle()));
this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
new mkldnn::eltwise_backward(
this->pd, mkldnn::primitive::at(*this->data),
*this->diff_dst_memory, *this->diff_src_memory));
}
}

const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
};

static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
const OpContext &ctx,
const NDArray &in_data,
const NDArray &out_grad,
const mkldnn::memory &in_mem) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
#endif
MKLDNNActSignature key(param);
key.AddSign(in_data);
key.AddSign(out_grad);

auto it = bwds.find(key);
if (it == bwds.end()) {
MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
auto ins_ret =
bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
CHECK(ins_ret.second);
it = ins_ret.first;
}
return it->second;
}

// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
// function, since the computation only involves non-zeros.
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
@@ -200,30 +294,13 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
// descriptor. Otherwise, the perf will suffer.
if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc();
auto cpu_engine = data_mpd.get_engine();

MKLDNNActBackward &bwd =
GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream *stream = MKLDNNStream::Get();
auto alg = GetMKLDNNActAlgo(param);
mkldnn_output_t diff_src_memory;

MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, {
DType alpha = 0;
mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
alg, data_md, alpha);
mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
fw_pdesc);

diff_src_memory = CreateMKLDNNMem(in_grad,
bw_pdesc.diff_src_primitive_desc(), req);
stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
*diff_dst_memory,
*diff_src_memory.second));
});
mkldnn_output_t diff_src_memory =
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
stream->RegisterPrim(bwd.GetBwd());
CommitOutput(in_grad, diff_src_memory);
stream->Submit();
}
5 changes: 5 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -509,6 +509,9 @@ class OpCheck {
const std::vector<mxnet::NDArray> &inputs_,
const std::vector<mxnet::OpReqType> &req,
const std::vector<mxnet::NDArray> &outputs_);

void CopyResult(const std::vector<mxnet::NDArray> &outputs_,
const std::vector<size_t>& indice);
};

bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
@@ -525,6 +528,8 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,

#define MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs) \
if (debug) check.Run(fn, attrs, ctx, inputs, req, outputs);
#define MKLDNN_OPCHECK_COPY_RESULT(outputs, indice) \
if (debug) check.CopyResult(outputs, indice);

} // namespace mxnet
#endif
11 changes: 11 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -525,6 +525,17 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs,
}
}

void OpCheck::CopyResult(const std::vector<mxnet::NDArray> &outputs_,
const std::vector<size_t> &indice) {
CHECK(!MKLDNNStream::Get()->HasOps());
auto non_const_outputs_ = const_cast<std::vector<mxnet::NDArray> &>(outputs_);
for (auto i = indice.begin(); i != indice.end(); ++i) {
auto mem = outputs[*i].GetMKLDNNData();
non_const_outputs_[*i].CopyFrom(*mem);
}
MKLDNNStream::Get()->Submit();
}

bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
bool support_mkldnn,
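
The lrn.cc hunk near the top of this diff shows how the three opcheck macros are meant to combine inside an operator's *ComputeExCPU body: MKLDNN_OPCHECK_INIT snapshots the inputs for the debug reference run, MKLDNN_OPCHECK_RUN executes the reference FCompute and checks it against the MKLDNN result, and the new MKLDNN_OPCHECK_COPY_RESULT then overwrites selected real outputs (the workspace at index 1 in LRN's case) with the reference-computed version, so that a later backward opcheck consumes data in the layout the reference implementation expects. The sketch below restates that usage with placeholder names: the SomeOp* identifiers and MKLDNNSomeOpForward are hypothetical, the macros and FallBackCompute come from the diffs above, and SupportMKLDNN is assumed from the surrounding MXNet code.

// Hypothetical operator body modeled on the LRNComputeExCPU hunk above; not code from this commit.
static void SomeOpComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                               const std::vector<NDArray> &inputs,
                               const std::vector<OpReqType> &req,
                               const std::vector<NDArray> &outputs) {
  if (SupportMKLDNN(inputs[0])) {
    // Snapshot inputs so the debug reference run sees the same data.
    MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);
    MKLDNNSomeOpForward(ctx, inputs[0], req[0], outputs[0]);
    // Run the reference implementation and compare it with the MKLDNN result.
    MKLDNN_OPCHECK_RUN(SomeOpCompute<cpu>, attrs, ctx, inputs, req, outputs);
    // Replace outputs[1] (e.g. a workspace) with the reference version so the
    // backward opcheck later receives data in the reference layout.
    MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
    return;
  }
  FallBackCompute(SomeOpCompute<cpu>, attrs, ctx, inputs, req, outputs);
}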
129 changes: 87 additions & 42 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -290,6 +290,84 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
}
}

class MKLDNNBNBackward {
std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
std::shared_ptr<mkldnn::memory> data_m;
std::shared_ptr<mkldnn::memory> diff_m;
std::shared_ptr<mkldnn::memory> gradi_m;
std::shared_ptr<mkldnn::memory> mean_m;
std::shared_ptr<mkldnn::memory> var_m;
const std::shared_ptr<mkldnn::memory> weight_m;
const std::shared_ptr<mkldnn::memory> gradw_m;

public:
const t_bn_b_pdesc pd;

explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
: weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
pd(_pd) {}

const mkldnn::memory &GetWeight() const { return *weight_m; }

const mkldnn::memory &GetGradw() const { return *gradw_m; }

void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
const NDArray &mean, const mkldnn::memory &var,
const mkldnn::memory &gradi) {
auto mean_ptr = mean.data().dptr_;
if (bwd == nullptr) {
data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
data.get_data_handle()));
diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
diff.get_data_handle()));
gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
gradi.get_data_handle()));
mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
var.get_data_handle()));
bwd.reset(new mkldnn::batch_normalization_backward(
pd, *data_m, mkldnn::primitive::at(*mean_m),
mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
*gradw_m));
} else {
data_m->set_data_handle(data.get_data_handle());
diff_m->set_data_handle(diff.get_data_handle());
gradi_m->set_data_handle(gradi.get_data_handle());
mean_m->set_data_handle(mean_ptr);
var_m->set_data_handle(var.get_data_handle());
}
}

const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
};

template <typename DType>
static MKLDNNBNBackward &GetBNBackward(
const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
const mkldnn::memory &in_mem, const NDArray &diff_data,
const mkldnn::memory &diff_mem, unsigned flags) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
#endif
MKLDNNBNSignature key(param);
key.AddSign(in_data);
key.AddSign(diff_data);

auto it = bwds.find(key);
if (it == bwds.end()) {
auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
MKLDNNBNBackward bwd(bwd_pd);
auto ins_ret =
bwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
CHECK(ins_ret.second);
it = ins_ret.first;
}
return it->second;
}

template <typename DType>
void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
const std::vector<NDArray> &out_grad,
@@ -326,17 +404,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
else if (diff.IsDefaultData())
diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
auto bwd_pd = _GetBwd(*data_mem, *diff_mem, param.eps, flags);
auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());

if (flags & use_scale_shift) {
const NDArray &gamma = in_data[batchnorm::kGamma];
const NDArray &beta = in_data[batchnorm::kBeta];
// TODO(tao): how to reuse this memory?
std::shared_ptr<const mkldnn::memory> weight_mem(
new mkldnn::memory(bwd_pd.weights_primitive_desc()));

DType* weight_buf = reinterpret_cast<DType *>(weight_mem->get_data_handle());
DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
nnvm::dim_t channels_ = data.shape()[1];
for (int i = 0; i < channels_; i++) {
if (!param.fix_gamma)
@@ -349,15 +423,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i]; // bias
}

std::shared_ptr<const mkldnn::memory> gradw_mem(
new mkldnn::memory(bwd_pd.diff_weights_primitive_desc()));
// training but no input mean and variance
if (ctx.is_train && !param.use_global_stats) {
DType* moving_mean_ptr = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
DType* moving_var_ptr = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
DType* out_mean_ptr = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
DType* out_var_ptr = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
mkldnn::memory var_mem(bwd_pd.variance_primitive_desc());
mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());

DType minus_mom = (1.0f - param.momentum);
@@ -369,45 +441,18 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
moving_var_ptr[i] = moving_var_ptr[i] * param.momentum +
variance * minus_mom;
}

std::shared_ptr<const mkldnn::memory> out_mean_mem(
new mkldnn::memory(bwd_pd.mean_primitive_desc(), out_mean_ptr));
std::shared_ptr<const mkldnn::memory> out_var_mem(
new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr));

auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
*data_mem,
mkldnn::primitive::at(*out_mean_mem),
mkldnn::primitive::at(var_mem),
*diff_mem,
*weight_mem,
*gradi_mem,
*gradw_mem);

MKLDNNStream::Get()->RegisterPrim(bn_bwd);
bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
MKLDNNStream::Get()->Submit();
} else {
std::shared_ptr<const mkldnn::memory> imean_mem(
new mkldnn::memory(bwd_pd.mean_primitive_desc(),
moving_mean.data().dptr<DType>()));
std::shared_ptr<const mkldnn::memory> ivar_mem(
new mkldnn::memory(bwd_pd.variance_primitive_desc(),
moving_var.data().dptr<DType>()));
auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
*data_mem,
mkldnn::primitive::at(*imean_mem),
mkldnn::primitive::at(*ivar_mem),
*diff_mem,
*weight_mem,
*gradi_mem,
*gradw_mem);

MKLDNNStream::Get()->RegisterPrim(bn_bwd);
bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
*moving_var.GetMKLDNNData(), *gradi_mem);
MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
MKLDNNStream::Get()->Submit();
}

// copy data from gradw_mem to in_grad[1] and in_grad[2]
DType* gw_buf = reinterpret_cast<DType *>(gradw_mem->get_data_handle());
DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
for (int i = 0; i < channels_; i++) {
if (!param.fix_gamma)
(in_grad[1].data().dptr<DType>())[i] = gw_buf[i];
