MKLDNN sum OP: implement primitive cache and class refactor #12
base: refactor
Conversation
This commit may add some overhead of managing NDArray for each fallback.
Conflicts: src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
2. Add memory into signature; 3. Try to split BatchNorm into .h file and .cc file. Will finish it after backward code is refactored.
Caching primitive for BatchNorm forward computation
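The commit above caches the BatchNorm forward primitive keyed by an op signature. A minimal standalone sketch of that caching pattern is below; all names (`StubBnFwd`, `OpSignature`, `GetBnFwd`) are hypothetical stand-ins, since the real code wraps `mkldnn::batch_normalization_forward` and MXNet's `MKLDNNOpSignature`:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an MKL-DNN forward primitive; the real
// class would hold the mkldnn primitive and its memory descriptors.
struct StubBnFwd {
  explicit StubBnFwd(float eps) : eps(eps) {}
  float eps;
};

// Simplified op signature: a stream of integers describing the op's
// parameters and input shapes, comparable and hashable as a whole.
struct OpSignature {
  std::vector<int64_t> vals;
  void AddSign(int64_t v) { vals.push_back(v); }
  bool operator==(const OpSignature &o) const { return vals == o.vals; }
};

struct OpSignatureHash {
  size_t operator()(const OpSignature &s) const {
    size_t h = 0;
    for (int64_t v : s.vals) h = h * 31 + std::hash<int64_t>()(v);
    return h;
  }
};

// One cached primitive per unique signature; a thread_local map
// avoids locking, at the cost of one cache per thread.
StubBnFwd &GetBnFwd(float eps, const std::vector<int64_t> &shape) {
  static thread_local std::unordered_map<OpSignature, StubBnFwd,
                                         OpSignatureHash> cache;
  OpSignature key;
  key.AddSign(static_cast<int64_t>(eps * 1e6f));  // fold param into key
  for (int64_t d : shape) key.AddSign(d);         // fold shape into key
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, StubBnFwd(eps)).first;  // create on first use
  return it->second;
}
```

Repeated calls with the same parameters and shapes then reuse one primitive instead of rebuilding it on every forward pass.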
return this->num_args == other.num_args;
}
};
Why do you create a new parameter? You don't have to use MKLDNNParamOpSign; you can use MKLDNNOpSignature.
The parameter definition is moved from src/operator/tensor/elemwise_sum.cc to this header file so it can be used by mkldnn_sum.cc. To support the "add_n" (alias "ElementWiseSum") OP, which has a valid param, I used MKLDNNParamOpSign.
param = nnvm::get<ElementWiseSumParam>(attrs.parsed);
} else {
memset(&param, 0, sizeof(param));
}
The parameter doesn't exist.
When MKLDNNSumCompute() is invoked via the "add_n" (alias "ElementWiseSum") OP, the param exists. To adapt to both scenarios and provide a unified interface, the param is set to 0 when there is none.
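The zero-fill fallback described above can be sketched standalone. This is a hedged illustration, not the actual MXNet code: `ElementWiseSumParam` here is reduced to a plain POD struct, and the `has_param` flag stands in for checking whether the op was invoked as "add_n" (where the real code reads the param from `attrs.parsed`):

```cpp
#include <cassert>
#include <cstring>

// Simplified POD version of the op parameter; zero-filling with
// memset is only safe because the struct is trivially copyable.
struct ElementWiseSumParam {
  int num_args;
};

// If the op carries a param (the "add_n" path), use it; otherwise
// zero-initialize so both paths feed the same signature/cache logic.
ElementWiseSumParam GetParam(bool has_param, int num_args_if_any) {
  ElementWiseSumParam param;
  if (has_param) {
    param.num_args = num_args_if_any;  // real code: nnvm::get<...>(attrs.parsed)
  } else {
    std::memset(&param, 0, sizeof(param));  // plain sum path has no param
  }
  return param;
}
```

The zeroed param acts as a neutral key so the same primitive-cache lookup works whether or not the invoking OP defines a parameter.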
src/operator/nn/mkldnn/mkldnn_sum.cc
Outdated
private:
std::shared_ptr<mkldnn::sum> fwd;
std::vector<std::shared_ptr<mkldnn::memory>> in_data;
mkldnn_output_t out;
You shouldn't use mkldnn_output_t. It is designed to be used with CreateMKLDNNMem and CommitOutput; it's not meant for holding the mkldnn memory, because its second field is a raw pointer, and we shouldn't use a raw pointer to hold memory.
Thanks for the suggestion; I will change it to use std::shared_ptr to hold the output memory.
src/operator/nn/mkldnn/mkldnn_sum.cc
Outdated
CommitOutput(out_data, out_mem);
stream->RegisterPrim(*(this->fwd));
auto out_mem = CreateMKLDNNMem(output, this->fwd_pd->dst_primitive_desc(),
                               req);
you shouldn't call CreateMKLDNNMem twice. The way you restructure the code makes it difficult to work with the interface I designed.
To invoke CreateMKLDNNMem() only once while keeping a unified SetDataHandle/Execute API (passing NDArray as parameters), one option is to define two members in MKLDNNSumFwd, "std::shared_ptr<mkldnn::memory> out" and "OutDataOp op", to save the created memory and data op for later use (maybe a little ugly). Please let me know if there is a better choice.
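The member-caching idea proposed above can be sketched with stub types. This is an assumption-laden illustration of the shape of the class, not the real MKLDNNSumFwd: `Memory` stands in for `mkldnn::memory`, the enum mimics MXNet's `OutDataOp`, and `Execute` computes the sum directly where the real code would register the cached primitive on the stream:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct Memory {            // stub for mkldnn::memory
  void *handle = nullptr;
};
enum class OutDataOp { Noop, CopyBack, AddBack };  // mimics MXNet's enum

// Hypothetical MKLDNNSumFwd shape: the output memory and the commit
// action are members, so CreateMKLDNNMem runs once (in SetDataHandle)
// and Execute only replays the cached primitive and commits.
class SumFwd {
 public:
  void SetDataHandle(const std::vector<void *> &inputs, void *output) {
    in_data_.clear();
    for (void *p : inputs) {
      auto m = std::make_shared<Memory>();
      m->handle = p;
      in_data_.push_back(m);
    }
    out_ = std::make_shared<Memory>();  // real code: CreateMKLDNNMem(...)
    out_->handle = output;
    op_ = OutDataOp::CopyBack;          // remembered for the later commit
  }

  void Execute(float *dst, const std::vector<const float *> &srcs, size_t n) {
    // Stand-in for: stream->RegisterPrim(*fwd); CommitOutput(out, {op_, ...});
    for (size_t i = 0; i < n; ++i) {
      float acc = 0.f;
      for (const float *s : srcs) acc += s[i];
      dst[i] = acc;
    }
  }

 private:
  std::vector<std::shared_ptr<Memory>> in_data_;
  std::shared_ptr<Memory> out_;       // shared_ptr, not a raw pointer
  OutDataOp op_ = OutDataOp::Noop;
};
```

Holding `out_` as a `shared_ptr` addresses the raw-pointer concern raised earlier, at the cost of the two extra members the author acknowledges may be a little ugly.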
Force-pushed from c5b06e8 to 8ba9736 (Compare)
* Added tutorial for FIT API
* Added tests for Fit API tutorial
* Updated index.md for the new tutorial to show up
* Addressed PR feedback
* Addressed PR feedback
* Removed spurious comment for Py2 and Py3 compatibility
* Address PR feedback
* Addressed PR feedback
* Fixed typo
* Added example to showcase custom event handler
* Fixed imports as estimator moved to contrib package
* Added a side note to inform about estimator reference being updated by the handlers
* Corrected typo
* update tutorial
* address comments
* new line
* fix import
* fix cached graph
* fix import
* address comments
* fix doc gen
* add softmax
* add to website index
* fix doc string
* Fix doc gen (zheng-da#12)
* fix warining
* fix test
* fix
* fix
* fix print
* fix test (zheng-da#13)
* fix warning (zheng-da#14)
* fix href (zheng-da#15)