[MXNET-264] Improve performance of MKLDNN in small batch sizes. #10317
Conversation
@@ -74,6 +74,7 @@ enum NDArrayFormatErr {
  kRSPIdxErr,  // indices error for row sparse
};

class MKLDNNMemory;
forward declarations are going out of style. Is there a reasonable way around this, or does it get messy without this?
I don't think it's a good idea to include the header file that defines MKLDNNMemory in ndarray.h. It's better to declare the class in this file. What is the reason for not using a forward declaration?
One reason is that they cause annoying compile errors when used with smart-pointer members, when the compiler decides it needs the complete type in order to generate the destructor code, for instance, or during template instantiation of something that uses it directly or indirectly. I'm not going to block the PR over it, and if you feel strongly that you want to use it then fine, but it's not done much in the code base and that's probably not an accident.
Between including mkldnn_base-inl.h and a forward declaration, I'll choose the latter. NDArray is such a basic class that its header file is included in almost every .cc file and many .h files, including mkldnn_base-inl.h. The other option is to define MKLDNNMemory inside NDArray, if that's preferred, but that seems a little odd to me.
src/ndarray/ndarray.cc
Outdated
auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
CHECK_NE(format, ptr_->mkl_mem_->get_primitive_desc().desc().data.format);
auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format);
auto format = ptr_->mkl_mem_->GetDefaultFormat();
Can we please reduce the usage of 'auto'? Usually 'auto' is best for obvious types or very tedious declarations such as STL map iterators. Here there are several calls where I have no idea what the object types are, which makes the code hard to understand. There are so many autos that it almost looks like JavaScript :)
// def_mem points to a memory region in the temp space. It's only valid
// inside an operator. As such, the returned NDArray can only be valid
// inside an operator and the shared pointer doesn't need to do anything
// when it's destroyed.
ret.ptr_->mkl_mem_ = std::shared_ptr<mkldnn::memory>(def_mem,
                                                     [](mkldnn::memory *mem){});
auto tmp = std::shared_ptr<mkldnn::memory>(def_mem, [](mkldnn::memory *mem){});
See, for example, this is a good use for auto, because the type is obvious from the assignment.
This reverts commit 71d0dec.
Could you please review this PR? @piiswrong @pengzhao-intel @TaoLv
@zheng-da Do you have any performance update for this PR?
This reverts commit 58854be.
Here is the performance before and after the optimizations.
  weight_buf[channels_ + i] = bias_ptr[i];  // bias
}
memcpy(weight_buf, weight_ptr, sizeof(weight_buf[0]) * channels_);
memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
Nice optimization above; why were the OMP calls below causing overhead?
Are you referring to all of the OMP directives? The number of channels is on the order of 100. At that size, the parallelization overhead is usually larger than the actual computation.
Got it, thanks! We noticed the same performance issue for smaller networks too (e.g. MNIST). A lower OMP_NUM_THREADS (e.g. 4 vs. 36) was giving better performance.
@cjolivier01 @piiswrong this PR resolves performance issues with small-batch-size inference. Can we get it into the 1.2.0 release please? Thanks.
Way too many autos.
mkldnn_memory_format_t GetDefaultFormat(int num_dims);
mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
                                                mkldnn_memory_format_t format);

static inline bool same_shape(const TShape &shape, const mkldnn_dims_t dims, int ndims) {
Don't use static in a header; inline is fine by itself.
src/operator/nn/lrn-inl.h
Outdated
@@ -58,8 +58,35 @@ struct LRNParam : public dmlc::Parameter<LRNParam> {
  DMLC_DECLARE_FIELD(nsize)
  .describe("normalization window width in elements.");
}

bool operator==(const LRNParam& other) const {
  return (fabs(this->alpha - other.alpha) < 1e-6 &&
It's better to check nsize first, because it's a far less expensive check than fabs().
  return true;
}

static inline bool same_shape(const TShape &shape, int dtype,
same here
  size = pd.get_size();
}

explicit MKLDNNMemory(std::shared_ptr<mkldnn::memory> mem): desc(
can this pointer be passed by reference to reduce the interlocked operation?
We need to use shared_ptr here; MKLDNNMemory needs to own the memory.
explicit MKLDNNMemory(std::shared_ptr<mkldnn::memory> mem): desc(
    mem->get_primitive_desc().desc()) {
  this->mem = mem;
  auto pd = mem->get_primitive_desc();
nit: it isn’t clear what auto is here
LGTM, thanks @zheng-da
src/ndarray/ndarray.cc
Outdated
auto format = GetDefaultFormat(mkl_mem_->get_primitive_desc().desc());
CHECK(format != mkl_mem_->get_primitive_desc().desc().data.format);
auto format = mkl_mem_->GetDefaultFormat();
CHECK(format != mkl_mem_->GetFormat());
CHECK_NE()
  return GetFormat() != GetDefaultFormat();
}

bool SameFormat(mkldnn::memory::primitive_desc pd) const {
nit: maybe HaveSameFormat is a better name.
void ReorderTo(mkldnn::memory *other) const {
  std::vector<mkldnn::primitive> net;
  net.push_back(mkldnn::reorder(*mem, *other));
  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
Why not use MKLDNNStream here?
We want an immediate reorder here. MKLDNNStream is designed to collect all MKLDNN operators and submit them in one call.
@@ -30,6 +30,67 @@
namespace mxnet {
namespace op {

class MKLDNNCcForward {
Is 'Cc' a short name for concat? Please find a more proper name for this class.
std::vector<mkldnn::primitive::at> data_mem;
std::vector<const mkldnn::memory *> data_mem;
data_md.reserve(num_in_data);
data_mem.reserve(num_in_data);
for (int i =0; i < num_in_data; i++) {
auto tmp_mem = in_data[i].GetMKLDNNData();
Please help fix the indents here.
@cjolivier01 I have removed "auto" as much as possible.
@cjolivier01 are you OK with the PR?
@cjolivier01 is this good to merge?
@piiswrong can this be merged?
…he#10317) * Create MKLDNNMemory to cache metadata. * Fix lint error. * Cache concat. * Fix a bug in NDArray. * improve hashing. * don't use omp for gamma and beta in batchnorm. * address the comments. * Avoid computing out mean&var in batchnorm. * Cache LRN. * Fix a bug in LRN. * Fix lint error. * Revert "Avoid computing out mean&var in batchnorm." This reverts commit 71d0dec. * remove more omp in batchnorm. * add comments for MKLDNNMemory. * Revert "improve hashing." This reverts commit 58854be. * Remove unnecessary TODO. * address comments. * Remove additional auto. * Fix compile error. * remove more auto.
Description
The current MKLDNN integration still has a lot of overhead when calling MKLDNN functions, and that overhead comes from many places. This PR reduces those overheads.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.