This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MKLDNN Backward op cache #11301

Merged
merged 27 commits into from
Sep 13, 2018
Commits
7c9e8e6
Merge pull request #1 from apache/master
ZhennanQin Jun 21, 2018
a35bc9a
Merge remote-tracking branch 'upstream/master'
ZhennanQin Jun 23, 2018
ad90147
Merge remote-tracking branch 'upstream/master'
ZhennanQin Jun 29, 2018
f3f09b7
Enable primitive allocation cache for _backward_LRN.
ZhennanQin Jun 13, 2018
d6dc8a8
Enable primitive allocation cache for _backward_Pooling.
ZhennanQin Jun 13, 2018
9e107d2
Enable primitive allocation cache for _backward_Activation.
ZhennanQin Jun 12, 2018
b2b71e1
Enable primitive allocation cache for _backward_Deconvolution.
ZhennanQin Jun 13, 2018
a58ad33
Enable primitive allocation cache for _backward_BatchNorm.
ZhennanQin Jun 13, 2018
97e0d34
Enable primitive allocation cache for _backward_Convolution
Jun 13, 2018
f7b9d30
Enable primitive allocation cache for _backward_Fully_Connected
Jun 13, 2018
09ab93a
Merge branch 'master' into backward_op_cache
ZhennanQin Jul 2, 2018
e9f6a33
remove fc forward and fix indent problem
huangzhiyuan Jul 9, 2018
2f3f436
remove fc forward and fix convolution indent problem
huangzhiyuan Jul 9, 2018
315abb8
Change log level to FATAL for unreachable code in mkldnn_act.cc
ZhennanQin Jul 9, 2018
21b1a68
remove fc forward and fix convolution indent problem
huangzhiyuan Jul 11, 2018
dea6f91
remove useless hint in fc
huangzhiyuan Jul 11, 2018
dee9bd6
Merge branch 'master' into backward_op_cache
ZhennanQin Jul 12, 2018
dd07d9f
Merge branch 'master' into backward_op_cache
ZhennanQin Jul 13, 2018
f160c11
Merge branch 'master' into backward_op_cache
huangzhiyuan Jul 13, 2018
89bafa8
Merge branch 'master' into backward_op_cache
ZhennanQin Jul 16, 2018
913a143
Empty commit to retrigger the CI.
ZhennanQin Jul 16, 2018
75039e1
Change LOG(INFO) to LOG(FATAL) for unreachable code in mkldnn_act.cc
ZhennanQin Jul 16, 2018
c8e976f
Merge branch 'master' into backward_op_cache
ZhennanQin Jul 25, 2018
d92915b
Fix build issue after code merge.
ZhennanQin Jul 25, 2018
e0805c8
Merge remote-tracking branch 'upstream/master' into backward_op_cache
ZhennanQin Aug 27, 2018
ae4a749
Fix lint after merge
ZhennanQin Aug 27, 2018
c34c603
Fix mkldnn act.
ZhennanQin Aug 31, 2018
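
Each "Enable primitive allocation cache for _backward_*" commit above applies the same idea: build the MKLDNN backward primitive once, keep it in a thread-local map keyed by an op signature, and on later calls only swap the data handles before re-submitting it. Below is a minimal, self-contained sketch of that pattern; Signature, SignatureHash, and BackwardPrimitive are hypothetical stand-ins, not code from this PR (the real versions, e.g. GetActBackward and GetBNBackward, appear in the diffs further down).

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an op signature: shape + dtype, hashable and comparable.
struct Signature {
  std::vector<int64_t> shape;
  int dtype;
  bool operator==(const Signature &other) const {
    return dtype == other.dtype && shape == other.shape;
  }
};

struct SignatureHash {
  size_t operator()(const Signature &s) const {
    size_t h = std::hash<int>()(s.dtype);
    for (int64_t d : s.shape) h = h * 31 + std::hash<int64_t>()(d);
    return h;
  }
};

// Hypothetical stand-in for an MKLDNN backward primitive: expensive to construct,
// cheap to point at new input/output buffers.
class BackwardPrimitive {
 public:
  explicit BackwardPrimitive(const Signature &sig) {
    std::cout << "building backward primitive for dtype " << sig.dtype << "\n";
  }
  void SetNewMem(const void *src, void *dst) { src_ = src; dst_ = dst; }
  void Execute() const {
    // A real primitive would read src_ and write dst_ here.
    if (src_ == nullptr || dst_ == nullptr) std::cout << "no memory attached\n";
  }

 private:
  const void *src_ = nullptr;
  void *dst_ = nullptr;
};

// The cached-backward pattern: one thread-local cache per op, keyed by the
// signature, so the primitive is created once and reused on every iteration.
BackwardPrimitive &GetCachedBackward(const Signature &key) {
  static thread_local std::unordered_map<Signature, BackwardPrimitive, SignatureHash> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, BackwardPrimitive(key)).first;
  return it->second;
}

int main() {
  Signature sig{{32, 3, 224, 224}, 0};
  float src[4] = {0}, dst[4] = {0};
  for (int iter = 0; iter < 3; ++iter) {  // "building ..." prints only once
    BackwardPrimitive &bwd = GetCachedBackward(sig);
    bwd.SetNewMem(src, dst);              // per call: swap data handles only
    bwd.Execute();
  }
  return 0;
}
```

The payoff is that the costly primitive-descriptor creation happens once per shape/dtype combination instead of on every backward pass.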
3 changes: 1 addition & 2 deletions src/operator/nn/fully_connected.cc
@@ -213,10 +213,9 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
uint32_t out_expected = param.no_bias ? 2 : 3;
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), out_expected);

bool dispatched = false;
// TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
// It seems there is a bug.
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
2 changes: 2 additions & 0 deletions src/operator/nn/lrn.cc
@@ -116,6 +116,8 @@ void LRNComputeExCPU(const nnvm::NodeAttrs &attrs,
MKLDNN_OPCHECK_INIT(false, 1, inputs, outputs);
MKLDNNLRNForward(ctx, param, inputs[0], req[0], outputs[0]);
MKLDNN_OPCHECK_RUN(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
// Copy outputs[1] from opcheck reference as backward check needs it.
MKLDNN_OPCHECK_COPY_RESULT(outputs, std::vector<size_t>{1});
Contributor:

Do we have to do it for all operators?

Contributor Author:

LRN is a special operator: MKLDNN generates only output[0], but the default CPU backward computation also requires output[1], so opcheck eventually fails. Copying output[1] fixes the problem.
As for whether this is necessary for all operators: first, we only found this issue on LRN; second, it would make results differ before and after enabling opcheck because of accumulated accuracy loss; third, it would hurt opcheck performance a lot. So I prefer not to apply this to all operators.

Contributor:

Thanks for explaining. I just didn't understand why we needed this.

return;
}
FallBackCompute(LRNCompute<cpu>, attrs, ctx, inputs, req, outputs);
127 changes: 102 additions & 25 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -84,7 +84,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
alg, data_md, alpha);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
});
LOG(INFO) << "Unsupported data type for MKLDNN activation";
LOG(FATAL) << "Unsupported data type for MKLDNN activation";
mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, alg, data_md, 0.0);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
@@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
stream->Submit();
}

static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
const ActivationParam &param, const mkldnn::memory &input_mem,
const mkldnn::memory &diff_dst_memory, int dtype) {
mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
auto cpu_engine = data_mpd.get_engine();
auto alg = GetMKLDNNActAlgo(param);

MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType alpha = 0;
mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
alg, data_md, alpha);
mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
fw_pdesc);
return bw_pdesc;
});
LOG(FATAL) << "Unsupported data type for MKLDNN activation";
mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
alg, data_md, 0.0);
mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
fw_pdesc);
Contributor:

The code shouldn't run here, right? Can we return an empty descriptor?

Contributor Author:

Yes, it should be unreachable code. Since it's hard to create an empty descriptor (MKLDNN doesn't provide a function for that), I think a fatal error is enough.
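
For illustration only, a minimal sketch of the control flow being discussed, with a hypothetical MakeDesc helper standing in for the real descriptor construction: every supported dtype returns from inside MSHADOW_REAL_TYPE_SWITCH, so the trailing code runs only for an unsupported dtype, where aborting via LOG(FATAL) is the practical choice because MKLDNN provides no way to build an empty primitive descriptor.

```cpp
// Sketch only (not code from this PR); MakeDesc<DType>() is hypothetical.
mkldnn::eltwise_backward::primitive_desc GetBwdDescSketch(int dtype) {
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    return MakeDesc<DType>();   // every supported dtype returns from here
  });
  // Reached only for an unsupported dtype: there is no "empty" primitive_desc
  // to return, so abort instead of handing back a bogus descriptor.
  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
  return MakeDesc<float>();     // never executed; keeps the compiler satisfied
}
```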

return bw_pdesc;
}

class MKLDNNActBackward {
std::shared_ptr<mkldnn::eltwise_backward> bwd;
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> diff_dst_memory;
std::shared_ptr<mkldnn::memory> diff_src_memory;

public:
const mkldnn::eltwise_backward::primitive_desc pd;

explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
const mkldnn::memory &mem,
const mkldnn::memory &diff_dst_memory)
: pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}

void SetNewMem(const mkldnn::memory &data,
const mkldnn::memory &diff_dst_memory,
const mkldnn::memory &diff_src_memory) {
if (this->bwd != nullptr) {
this->data->set_data_handle(data.get_data_handle());
this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
} else {
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
data.get_primitive_desc(), data.get_data_handle()));
this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
diff_dst_memory.get_data_handle()));
this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(diff_src_memory.get_primitive_desc(),
diff_src_memory.get_data_handle()));
this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
new mkldnn::eltwise_backward(
this->pd, mkldnn::primitive::at(*this->data),
*this->diff_dst_memory, *this->diff_src_memory));
}
}

const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
};

static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
const OpContext &ctx,
const NDArray &in_data,
const NDArray &out_grad,
const mkldnn::memory &in_mem) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
#endif
MKLDNNActSignature key(param);
key.AddSign(in_data);
key.AddSign(out_grad);

auto it = bwds.find(key);
if (it == bwds.end()) {
MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
auto ins_ret =
bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
CHECK(ins_ret.second);
it = ins_ret.first;
}
return it->second;
}

// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
// function, since the computation only involves non-zeros.
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
@@ -193,37 +193,20 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
in_buffer = in_data.Reorder2Default();

const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
auto diff_dst_memory = out_buffer.GetMKLDNNData();
auto input_mem = in_buffer.GetMKLDNNData();
// We need to make sure the two inputs to eltwise_backward have the same memory
// descriptor. Otherwise, the perf will suffer.
if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc();
auto cpu_engine = data_mpd.get_engine();

MKLDNNActBackward &bwd =
GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream *stream = MKLDNNStream::Get();
auto alg = GetMKLDNNActAlgo(param);
mkldnn_output_t diff_src_memory;

MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, {
DType alpha = 0;
mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
alg, data_md, alpha);
mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
fw_pdesc);

diff_src_memory = CreateMKLDNNMem(in_grad,
bw_pdesc.diff_src_primitive_desc(), req);
stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
*diff_dst_memory,
*diff_src_memory.second));
});
TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
Contributor:

Why did you move it here?

Contributor Author:

Can't remember the details; maybe it was changed by mistake. I've changed it back.

mkldnn_output_t diff_src_memory =
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
stream->RegisterPrim(bwd.GetBwd());
CommitOutput(in_grad, diff_src_memory);
stream->Submit();
}
5 changes: 5 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -497,6 +497,9 @@ class OpCheck {
const std::vector<mxnet::NDArray> &inputs_,
const std::vector<mxnet::OpReqType> &req,
const std::vector<mxnet::NDArray> &outputs_);

void CopyResult(const std::vector<mxnet::NDArray> &outputs_,
const std::vector<size_t>& indice);
};

bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
@@ -513,6 +516,8 @@ bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,

#define MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs) \
if (debug) check.Run(fn, attrs, ctx, inputs, req, outputs);
#define MKLDNN_OPCHECK_COPY_RESULT(outputs, indice) \
if (debug) check.CopyResult(outputs, indice);

} // namespace mxnet
#endif
11 changes: 11 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -525,6 +525,17 @@ void OpCheck::Run(mxnet::FCompute fn, const nnvm::NodeAttrs &attrs,
}
}

void OpCheck::CopyResult(const std::vector<mxnet::NDArray> &outputs_,
const std::vector<size_t> &indice) {
CHECK(!MKLDNNStream::Get()->HasOps());
auto non_const_outputs_ = const_cast<std::vector<mxnet::NDArray> &>(outputs_);
for (auto i = indice.begin(); i != indice.end(); ++i) {
auto mem = outputs[*i].GetMKLDNNData();
non_const_outputs_[*i].CopyFrom(*mem);
}
MKLDNNStream::Get()->Submit();
}

bool MKLDNNStorageType(const nnvm::NodeAttrs &attrs,
const int dev_mask,
bool support_mkldnn,
129 changes: 87 additions & 42 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -290,6 +290,84 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
}
}

class MKLDNNBNBackward {
std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
std::shared_ptr<mkldnn::memory> data_m;
std::shared_ptr<mkldnn::memory> diff_m;
std::shared_ptr<mkldnn::memory> gradi_m;
std::shared_ptr<mkldnn::memory> mean_m;
std::shared_ptr<mkldnn::memory> var_m;
const std::shared_ptr<mkldnn::memory> weight_m;
const std::shared_ptr<mkldnn::memory> gradw_m;

public:
const t_bn_b_pdesc pd;

explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
: weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
pd(_pd) {}

const mkldnn::memory &GetWeight() const { return *weight_m; }

const mkldnn::memory &GetGradw() const { return *gradw_m; }

void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
const NDArray &mean, const mkldnn::memory &var,
const mkldnn::memory &gradi) {
auto mean_ptr = mean.data().dptr_;
if (bwd == nullptr) {
data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
data.get_data_handle()));
diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
diff.get_data_handle()));
gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
gradi.get_data_handle()));
mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
var.get_data_handle()));
bwd.reset(new mkldnn::batch_normalization_backward(
pd, *data_m, mkldnn::primitive::at(*mean_m),
mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
*gradw_m));
} else {
data_m->set_data_handle(data.get_data_handle());
diff_m->set_data_handle(diff.get_data_handle());
gradi_m->set_data_handle(gradi.get_data_handle());
mean_m->set_data_handle(mean_ptr);
var_m->set_data_handle(var.get_data_handle());
}
}

const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
};

template <typename DType>
static MKLDNNBNBackward &GetBNBackward(
const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
const mkldnn::memory &in_mem, const NDArray &diff_data,
const mkldnn::memory &diff_mem, unsigned flags) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
#endif
MKLDNNBNSignature key(param);
key.AddSign(in_data);
key.AddSign(diff_data);

auto it = bwds.find(key);
if (it == bwds.end()) {
auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
MKLDNNBNBackward bwd(bwd_pd);
auto ins_ret =
bwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
CHECK(ins_ret.second);
it = ins_ret.first;
}
return it->second;
}

template <typename DType>
void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
const std::vector<NDArray> &out_grad,
@@ -326,17 +404,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
else if (diff.IsDefaultData())
diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
auto bwd_pd = _GetBwd(*data_mem, *diff_mem, param.eps, flags);
auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());

if (flags & use_scale_shift) {
const NDArray &gamma = in_data[batchnorm::kGamma];
const NDArray &beta = in_data[batchnorm::kBeta];
// TODO(tao): how to reuse this memory?
std::shared_ptr<const mkldnn::memory> weight_mem(
new mkldnn::memory(bwd_pd.weights_primitive_desc()));

DType* weight_buf = reinterpret_cast<DType *>(weight_mem->get_data_handle());
DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
nnvm::dim_t channels_ = data.shape()[1];
for (int i = 0; i < channels_; i++) {
if (!param.fix_gamma)
Expand All @@ -349,15 +423,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i]; // bias
}

std::shared_ptr<const mkldnn::memory> gradw_mem(
new mkldnn::memory(bwd_pd.diff_weights_primitive_desc()));
// training but no input mean and variance
if (ctx.is_train && !param.use_global_stats) {
DType* moving_mean_ptr = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
DType* moving_var_ptr = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
DType* out_mean_ptr = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
DType* out_var_ptr = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
mkldnn::memory var_mem(bwd_pd.variance_primitive_desc());
mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());

DType minus_mom = (1.0f - param.momentum);
Expand All @@ -369,45 +441,18 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
moving_var_ptr[i] = moving_var_ptr[i] * param.momentum +
variance * minus_mom;
}

std::shared_ptr<const mkldnn::memory> out_mean_mem(
new mkldnn::memory(bwd_pd.mean_primitive_desc(), out_mean_ptr));
std::shared_ptr<const mkldnn::memory> out_var_mem(
new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr));

auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
*data_mem,
mkldnn::primitive::at(*out_mean_mem),
mkldnn::primitive::at(var_mem),
*diff_mem,
*weight_mem,
*gradi_mem,
*gradw_mem);

MKLDNNStream::Get()->RegisterPrim(bn_bwd);
bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
MKLDNNStream::Get()->Submit();
} else {
std::shared_ptr<const mkldnn::memory> imean_mem(
new mkldnn::memory(bwd_pd.mean_primitive_desc(),
moving_mean.data().dptr<DType>()));
std::shared_ptr<const mkldnn::memory> ivar_mem(
new mkldnn::memory(bwd_pd.variance_primitive_desc(),
moving_var.data().dptr<DType>()));
auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
*data_mem,
mkldnn::primitive::at(*imean_mem),
mkldnn::primitive::at(*ivar_mem),
*diff_mem,
*weight_mem,
*gradi_mem,
*gradw_mem);

MKLDNNStream::Get()->RegisterPrim(bn_bwd);
bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
*moving_var.GetMKLDNNData(), *gradi_mem);
MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
MKLDNNStream::Get()->Submit();
}

// copy data from gradw_mem to in_grad[1] and in_grad[2]
DType* gw_buf = reinterpret_cast<DType *>(gradw_mem->get_data_handle());
DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
for (int i = 0; i < channels_; i++) {
if (!param.fix_gamma)
(in_grad[1].data().dptr<DType>())[i] = gw_buf[i];