Pull request: MKLDNN Backward op cache #11301 (Merged)
Changes from all commits — 27 commits:
7c9e8e6  Merge pull request #1 from apache/master (ZhennanQin)
a35bc9a  Merge remote-tracking branch 'upstream/master' (ZhennanQin)
ad90147  Merge remote-tracking branch 'upstream/master' (ZhennanQin)
f3f09b7  Enable primitive allocation cache for _backward_LRN. (ZhennanQin)
d6dc8a8  Enable primitive allocation cache for _backward_Pooling. (ZhennanQin)
9e107d2  Enable primitive allocation cache for _backward_Activation. (ZhennanQin)
b2b71e1  Enable primitive allocation cache for _backward_Deconvolution. (ZhennanQin)
a58ad33  Enable primitive allocation cache for _backward_BatchNorm. (ZhennanQin)
97e0d34  Enable primitive allocation cache for _backward_Convolution
f7b9d30  Enable primitive allocation cache for _backward_Fully_Connected
09ab93a  Merge branch 'master' into backward_op_cache (ZhennanQin)
e9f6a33  remove fc forward and fix indent problem (huangzhiyuan)
2f3f436  remove fc forward and fix convolution indent problem (huangzhiyuan)
315abb8  Change log level to FATAL for unreachable code in mkldnn_act.cc (ZhennanQin)
21b1a68  remove fc forward and fix convolution indent problem (huangzhiyuan)
dea6f91  remove useless hint in fc (huangzhiyuan)
dee9bd6  Merge branch 'master' into backward_op_cache (ZhennanQin)
dd07d9f  Merge branch 'master' into backward_op_cache (ZhennanQin)
f160c11  Merge branch 'master' into backward_op_cache (huangzhiyuan)
89bafa8  Merge branch 'master' into backward_op_cache (ZhennanQin)
913a143  Empty commit to retrigger the CI. (ZhennanQin)
75039e1  Change LOG(INFO) to LOG(FATAL) for unreachable code in mkldnn_act.cc (ZhennanQin)
c8e976f  Merge branch 'master' into backward_op_cache (ZhennanQin)
d92915b  Fix build issue after code merge. (ZhennanQin)
e0805c8  Merge remote-tracking branch 'upstream/master' into backward_op_cache (ZhennanQin)
ae4a749  Fix lint after merge (ZhennanQin)
c34c603  Fix mkldnn act. (ZhennanQin)
Changes to mkldnn_act.cc:

@@ -84,7 +84,7 @@ static mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
                                           alg, data_md, alpha);
     return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
   });
-  LOG(INFO) << "Unsupported data type for MKLDNN activation";
+  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
   mkldnn::eltwise_forward::desc desc = mkldnn::eltwise_forward::desc(
       mkldnn::prop_kind::forward_training, alg, data_md, 0.0);
   return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
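Context for this change: the MSHADOW_REAL_TYPE_SWITCH above returns a descriptor for every supported dtype, so control should never reach the logging line; promoting it from INFO to FATAL makes any unexpected fall-through abort loudly instead of silently constructing a dummy float descriptor. The dummy construction after the log only exists to give the non-void function a syntactic return path. A minimal, self-contained sketch of that pattern (illustrative C++ only, not MXNet code):

#include <stdexcept>

// Every supported case returns directly, so the tail should be unreachable.
// Throwing (the stand-in for LOG(FATAL) here) makes a fall-through loud;
// the dummy return after it only placates "missing return" diagnostics.
static int DescriptorForDtype(int dtype) {
  switch (dtype) {
    case 0: return 32;  // e.g. the float32 path returns its descriptor
    case 1: return 64;  // e.g. the float64 path
  }
  throw std::runtime_error("Unsupported data type");  // like LOG(FATAL)
  return -1;  // never executed
}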
@@ -175,6 +175,100 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   stream->Submit();
 }
 
+static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+    const ActivationParam &param, const mkldnn::memory &input_mem,
+    const mkldnn::memory &diff_dst_memory, int dtype) {
+  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
+  mkldnn::memory::desc data_md = data_mpd.desc();
+  mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
+  auto cpu_engine = data_mpd.get_engine();
+  auto alg = GetMKLDNNActAlgo(param);
+
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    DType alpha = 0;
+    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
+                                          alg, data_md, alpha);
+    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
+    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
+    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
+                                                      fw_pdesc);
+    return bw_pdesc;
+  });
+  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
+  mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
+                                        alg, data_md, 0.0);
+  mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
+  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
+  mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
+                                                    fw_pdesc);
Review comment on the fallback path above:

  Reviewer: The code shouldn't run here, right? Can we return an empty descriptor?

  Author: Yes, it should be unreachable code. Since it's hard to create an empty descriptor (MKLDNN doesn't provide a function for that), I think a fatal error is enough.
+  return bw_pdesc;
+}
+
+class MKLDNNActBackward {
+  std::shared_ptr<mkldnn::eltwise_backward> bwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> diff_dst_memory;
+  std::shared_ptr<mkldnn::memory> diff_src_memory;
+
+ public:
+  const mkldnn::eltwise_backward::primitive_desc pd;
+
+  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
+                             const mkldnn::memory &mem,
+                             const mkldnn::memory &diff_dst_memory)
+      : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}
+
+  void SetNewMem(const mkldnn::memory &data,
+                 const mkldnn::memory &diff_dst_memory,
+                 const mkldnn::memory &diff_src_memory) {
+    if (this->bwd != nullptr) {
+      this->data->set_data_handle(data.get_data_handle());
+      this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
+      this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
+    } else {
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          data.get_primitive_desc(), data.get_data_handle()));
+      this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
+                             diff_dst_memory.get_data_handle()));
+      this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(diff_src_memory.get_primitive_desc(),
+                             diff_src_memory.get_data_handle()));
+      this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
+          new mkldnn::eltwise_backward(
+              this->pd, mkldnn::primitive::at(*this->data),
+              *this->diff_dst_memory, *this->diff_src_memory));
+    }
+  }
+
+  const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
+};
+
+static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
+                                                const OpContext &ctx,
+                                                const NDArray &in_data,
+                                                const NDArray &out_grad,
+                                                const mkldnn::memory &in_mem) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNActSignature, MKLDNNActBackward, OpHash> bwds;
+#endif
+  MKLDNNActSignature key(param);
+  key.AddSign(in_data);
+  key.AddSign(out_grad);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
+    auto ins_ret =
+        bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
+// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
+// function, since the computation only involves non-zeros.
 void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
@@ -200,30 +294,13 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   // descriptor. Otherwise, the perf will suffer.
   if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
     input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
-  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  mkldnn::memory::desc diff_md = diff_dst_memory->get_primitive_desc().desc();
-  auto cpu_engine = data_mpd.get_engine();
-
+  MKLDNNActBackward &bwd =
+      GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream *stream = MKLDNNStream::Get();
-  auto alg = GetMKLDNNActAlgo(param);
-  mkldnn_output_t diff_src_memory;
-
-  MSHADOW_REAL_TYPE_SWITCH(in_buffer.dtype(), DType, {
-    DType alpha = 0;
-    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                          alg, data_md, alpha);
-    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
-    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
-                                                      fw_pdesc);
-
-    diff_src_memory = CreateMKLDNNMem(in_grad,
-                                      bw_pdesc.diff_src_primitive_desc(), req);
-    stream->RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
-                                                  *diff_dst_memory,
-                                                  *diff_src_memory.second));
-  });
+  mkldnn_output_t diff_src_memory =
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
+  bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
+  stream->RegisterPrim(bwd.GetBwd());
   CommitOutput(in_grad, diff_src_memory);
   stream->Submit();
 }
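Taken together, the hunks implement the "backward op cache" the PR title describes: the expensive primitive_desc/primitive construction moves out of the per-iteration path into a thread-local map keyed by an op signature, and each later call only repoints the cached memory objects at the new buffers via SetNewMem. A standalone sketch of that caching pattern, with all names hypothetical (stand-ins for MKLDNNActSignature, OpHash, and mkldnn::eltwise_backward, not real MXNet/MKLDNN APIs):

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical op signature: would encode op params, shapes, and dtypes.
struct Signature {
  std::string key;
  bool operator==(const Signature &other) const { return key == other.key; }
};
struct SignatureHash {
  std::size_t operator()(const Signature &s) const {
    return std::hash<std::string>()(s.key);
  }
};

// Hypothetical stand-in for an expensive-to-build backward primitive.
struct BackwardPrimitive {
  void *src = nullptr;       // data handles repointed on every call,
  void *diff_dst = nullptr;  // mirroring MKLDNNActBackward::SetNewMem()
  void *diff_src = nullptr;
  void SetNewMem(void *s, void *dd, void *ds) {
    src = s;
    diff_dst = dd;
    diff_src = ds;
  }
};

// Build once per unique signature per thread; reuse on later iterations.
static BackwardPrimitive &GetCachedBwd(const Signature &sig) {
  static thread_local std::unordered_map<Signature, BackwardPrimitive,
                                         SignatureHash> cache;
  auto it = cache.find(sig);
  if (it == cache.end())
    it = cache.emplace(sig, BackwardPrimitive{}).first;  // expensive setup
  return it->second;
}

The map is thread-local for the same reason as in the diff: the cached primitive's memory handles are mutated in place on every call, so a per-thread cache avoids both locking and cross-thread interference on the hot path.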
Review discussion on the _backward_LRN changes:

  Reviewer: Do we have to do this for all operators?

  Author: LRN is a special operator: MKLDNN generates only output[0], but the default CPU backward computation requires output[1] as well, which eventually makes opcheck fail. Copying output[1] fixes this problem.
  You may be wondering whether this is necessary for all operators. First, we have only found this issue with LRN. Second, it would make results differ before and after enabling opcheck because of accumulated accuracy loss. Third, it would hurt opcheck performance a lot. So I prefer not to apply this to all operators.

  Reviewer: Thanks for explaining. I just didn't understand why we needed this.
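The copy being discussed is not part of the diff shown above; as a rough, hypothetical illustration of the idea (names and buffer layout invented for this sketch, not MXNet's actual opcheck code), the fix amounts to filling in the output the MKLDNN path does not produce before the element-wise comparison runs:

#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical output record: a raw buffer plus its size in bytes.
struct Output {
  void *data = nullptr;
  std::size_t bytes = 0;
};

// Per the discussion: MKLDNN's LRN backward produces only outputs[0], while
// the default CPU path used by opcheck also needs outputs[1]. Copying the
// extra output across keeps the two paths comparable.
static void CopyExtraLRNOutput(std::vector<Output> *outputs,
                               const Output &extra) {
  if (outputs->size() > 1)
    std::memcpy((*outputs)[1].data, extra.data, extra.bytes);
}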