MKLDNN Backward op cache #11301
Changes from 11 commits
@@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
  stream->Submit();
}

static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
    const ActivationParam &param, const mkldnn::memory &input_mem,
    const mkldnn::memory &diff_dst_memory, int dtype) {
  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
  mkldnn::memory::desc data_md = data_mpd.desc();
  mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
  auto cpu_engine = data_mpd.get_engine();
  auto alg = GetMKLDNNActAlgo(param);

  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    DType alpha = 0;
    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
                                          alg, data_md, alpha);
    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
                                                      fw_pdesc);
    return bw_pdesc;
  });
  LOG(INFO) << "Unsupported data type for MKLDNN activation";
Review comments on this line:
- "I think you should use LOG(FATAL) here."
- "Agreed."
- "It seems you changed the one above, but not this one."

  mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
                                        alg, data_md, 0.0);
  mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
  mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
                                                    fw_pdesc);
Review comments on this line:
- "The code shouldn't run here, right? Can we return an empty descriptor?"
- "Yes, it should be unreachable code. As it's hard to create an empty descriptor (MKLDNN doesn't provide a function to do that), I think a fatal error is enough."

  return bw_pdesc;
}
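As the reviewers note above, the fallback after MSHADOW_REAL_TYPE_SWITCH is intended to be unreachable, and MKLDNN provides no function to build an empty descriptor, so failing hard is the practical option. The following standalone sketch (plain C++ with illustrative names such as DType, BwdDesc and GetBwdDesc, none of which are MXNet or MKLDNN API) shows the pattern being discussed: every supported type returns from the dispatch, and the fall-through path aborts instead of fabricating a descriptor.

```cpp
// Minimal standalone sketch of the "unreachable fallback" pattern discussed
// above; DType, BwdDesc and MakeDesc are illustrative stand-ins, not MKLDNN API.
#include <cstdlib>
#include <iostream>

enum class DType { kFloat32, kFloat64, kInt8 };
struct BwdDesc { double alpha; };

static BwdDesc MakeDesc(double alpha) { return BwdDesc{alpha}; }

BwdDesc GetBwdDesc(DType dtype) {
  switch (dtype) {
    case DType::kFloat32:
    case DType::kFloat64:
      return MakeDesc(0.0);  // every supported branch returns a real descriptor
    default:
      // Equivalent of LOG(FATAL): this point must never be reached, so abort
      // loudly rather than returning a descriptor built from dummy values.
      std::cerr << "Unsupported data type for MKLDNN activation" << std::endl;
      std::abort();
  }
}
```

Calling GetBwdDesc with the unsupported value terminates the process, which is the behaviour LOG(FATAL) gives in the MXNet code, rather than silently continuing with an invalid descriptor.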

class MKLDNNActBackward {
  std::shared_ptr<mkldnn::eltwise_backward> bwd;
  std::shared_ptr<mkldnn::memory> data;
  std::shared_ptr<mkldnn::memory> diff_dst_memory;
  std::shared_ptr<mkldnn::memory> diff_src_memory;

 public:
  const mkldnn::eltwise_backward::primitive_desc pd;

  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
                             const mkldnn::memory &mem,
                             const mkldnn::memory &diff_dst_memory)
      : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}

  void SetNewMem(const mkldnn::memory &data,
                 const mkldnn::memory &diff_dst_memory,
                 const mkldnn::memory &diff_src_memory) {
    if (this->bwd != nullptr) {
      this->data->set_data_handle(data.get_data_handle());
      this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
      this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
    } else {
      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
          data.get_primitive_desc(), data.get_data_handle()));
      this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
          new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
                             diff_dst_memory.get_data_handle()));
      this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
          new mkldnn::memory(diff_src_memory.get_primitive_desc(),
                             diff_src_memory.get_data_handle()));
      this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
          new mkldnn::eltwise_backward(
              this->pd, mkldnn::primitive::at(*this->data),
              *this->diff_dst_memory, *this->diff_src_memory));
    }
  }

  const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
};
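The class above is the core of the backward op cache: the mkldnn::eltwise_backward primitive and its memory wrappers are built on the first call to SetNewMem, and every later call only repoints the data handles. Below is a minimal standalone sketch of that create-once/rebind-handles idea; Buffer, Primitive and CachedBackward are illustrative stand-ins, not MXNet or MKLDNN types.

```cpp
// Standalone sketch of the create-once / rebind-handles idea behind
// MKLDNNActBackward::SetNewMem. All names here are illustrative.
#include <memory>

struct Buffer { void *handle; };                       // stand-in for mkldnn::memory
struct Primitive { Buffer *in; Buffer *gout; Buffer *gin; };  // stand-in for eltwise_backward

class CachedBackward {
  std::shared_ptr<Buffer> in_, gout_, gin_;
  std::shared_ptr<Primitive> prim_;

 public:
  void SetNewMem(void *in, void *gout, void *gin) {
    if (prim_ != nullptr) {
      // Cheap path on every later call: only the data pointers move.
      in_->handle = in;
      gout_->handle = gout;
      gin_->handle = gin;
    } else {
      // Expensive path, taken exactly once: build the wrappers and the
      // primitive that will be reused for the lifetime of the cache entry.
      in_ = std::make_shared<Buffer>(Buffer{in});
      gout_ = std::make_shared<Buffer>(Buffer{gout});
      gin_ = std::make_shared<Buffer>(Buffer{gin});
      prim_ = std::make_shared<Primitive>(
          Primitive{in_.get(), gout_.get(), gin_.get()});
    }
  }
  const Primitive &GetBwd() const { return *prim_; }
};
```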

static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
                                                const OpContext &ctx,
                                                const NDArray &in_data,
                                                const NDArray &out_grad,
                                                const mkldnn::memory &in_mem) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
#else
  static MX_THREAD_LOCAL std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
#endif
  MKLDNNActSignature key(param);
  key.AddSign(in_data);
  key.AddSign(out_grad);

  auto it = bwds.find(key);
  if (it == bwds.end()) {
    MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
    auto ins_ret =
        bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
    CHECK(ins_ret.second);
    it = ins_ret.first;
  }
  return it->second;
}
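GetActBackward keys the cached backward objects on a signature built from the activation parameters and the input/gradient arrays, and keeps one map per thread so no locking is needed. A minimal sketch of the same lookup pattern, with hypothetical Signature, SignatureHash and CachedBwd types standing in for MKLDNNActSignature, OpHash and MKLDNNActBackward:

```cpp
// Standalone sketch of the thread-local, signature-keyed primitive cache used
// by GetActBackward. Signature, SignatureHash and CachedBwd are hypothetical.
#include <cstddef>
#include <functional>
#include <unordered_map>

struct Signature {
  int act_type;
  size_t shape_hash;
  bool operator==(const Signature &o) const {
    return act_type == o.act_type && shape_hash == o.shape_hash;
  }
};

struct SignatureHash {
  size_t operator()(const Signature &s) const {
    return std::hash<int>()(s.act_type) * 131 + std::hash<size_t>()(s.shape_hash);
  }
};

struct CachedBwd { /* would hold the mkldnn primitive in the real code */ };

CachedBwd &GetCachedBwd(const Signature &key) {
  // One cache per thread, so concurrent executor threads never contend.
  static thread_local std::unordered_map<Signature, CachedBwd, SignatureHash> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, CachedBwd{}).first;  // built once, reused afterwards
  return it->second;
}
```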

// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
// function, since the computation only involves non-zeros.
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
@@ -193,37 +287,20 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
  in_buffer = in_data.Reorder2Default();

  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
  TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
  auto diff_dst_memory = out_buffer.GetMKLDNNData();
  auto input_mem = in_buffer.GetMKLDNNData();
  // We need to make sure the two inputs to eltwise_backward have the same memory
  // descriptor. Otherwise, the perf will suffer.
  if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
  mkldnn::memory::desc data_md = data_mpd.desc();
  mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc();
  auto cpu_engine = data_mpd.get_engine();

  MKLDNNActBackward &bwd =
      GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
  MKLDNNStream *stream = MKLDNNStream::Get();
  auto alg = GetMKLDNNActAlgo(param);
  mkldnn_output_t diff_src_memory;

  MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, {
    DType alpha = 0;
    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
                                          alg, data_md, alpha);
    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
                                                      fw_pdesc);

    diff_src_memory = CreateMKLDNNMem(in_grad,
                                      bw_pdesc.diff_src_primitive_desc(), req);
    stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
                                                  *diff_dst_memory,
                                                  *diff_src_memory.second));
  });
  TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]);
Review comments on this line:
- "Why do you move it here?"
- "Can't remember the details, maybe it was changed by mistake. I've changed it back."

  mkldnn_output_t diff_src_memory =
      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
  bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
  stream->RegisterPrim(bwd.GetBwd());
  CommitOutput(in_grad, diff_src_memory);
  stream->Submit();
}
Review comments:
- "Do we have to do it for all operators?"
- "LRN is a special operator: MKLDNN only generates output[0], but the default CPU backward computation requires output[1] as well, which eventually makes opcheck fail. After copying output[1], this problem is fixed. Maybe you're wondering whether it's necessary for all operators. First, we only found this issue on LRN. Second, it would make results differ before and after enabling opcheck because of accumulated accuracy loss. Third, it would hurt opcheck performance a lot. So I prefer not to apply this to all operators."
- "Thanks for explaining. I just didn't understand why we needed this."
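To make the explanation above concrete, here is a rough sketch of the kind of fix being described: the MKLDNN path only produces output[0], so the extra output that the default CPU backward reads is filled in before the opcheck comparison runs. All names are hypothetical, and the sketch assumes the missing workspace output comes from the reference CPU run; it is an illustration of the idea only, not the actual OpCheck code.

```cpp
// Illustrative sketch only: copy the one extra output (the LRN workspace) that
// the reference CPU backward needs but the MKLDNN forward does not produce.
#include <vector>

struct Blob { std::vector<float> data; };

void FixUpLRNOutputsForOpCheck(const std::vector<Blob> &reference_outputs,
                               std::vector<Blob> *mkldnn_outputs) {
  // Only output[1] is copied. Copying every output for every operator would
  // accumulate rounding differences and slow opcheck down, which is why the
  // change is limited to LRN.
  if (reference_outputs.size() > 1 && mkldnn_outputs->size() > 1)
    (*mkldnn_outputs)[1] = reference_outputs[1];
}
```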