LRN: caching OP and pass workspace from FW to BW #15
Conversation
This commit may add some overhead from managing an NDArray for each fallback.
Conflicts: src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
2. Add memory into the signature; 3. Try to split BatchNorm into a .h file and a .cc file. Will finish this after the backward code is refactored.
Caching primitive for BatchNorm forward computation
Add primitive caching for Pooling forward computation
OP primitive cache: use memory as signature for MKLDNN storage type
src/operator/nn/lrn.cc
@@ -42,6 +50,22 @@ static bool LRNShape(const nnvm::NodeAttrs& attrs,
   out_shape->clear();
   out_shape->push_back(dshape);
   out_shape->push_back(dshape);
 #if MXNET_USE_MKLDNN == 1
   // Create LRN primitive for getting the workspace size
   CHECK_EQ(dshape.ndim(), 4U);
Does MXNet LRN always run on 4D arrays? We should provide full compatibility with the original MXNet operator.
In general, MKL-DNN is specific to deep learning, so the default input tensor is 4D for all OPs. Currently, the 4D shape is fully supported by MKL-DNN, and 2D can also work for some OPs.
http://01org.github.io/mkl-dnn/group__c__api__lrn.html
A 2D shape here is not very computation-intensive, so using MKL-DNN for it would introduce extra overhead.
I suggest we only enable the 4D calculation for MKL-DNN, along the lines of the sketch below.
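A minimal sketch of what such a gate could look like; SupportMKLDNNLRN is an illustrative helper name (not necessarily what this PR uses), and the fp32 restriction mirrors the current hard-coded type:

```cpp
// Gate the MKL-DNN LRN path on 4D (NCHW) fp32 input; anything else falls back
// to the stock MXNet CPU implementation.
static inline bool SupportMKLDNNLRN(const NDArray &input) {
  return input.shape().ndim() == 4 && input.dtype() == mshadow::kFloat32;
}
```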
                            static_cast<int>(dshape[1]),
                            static_cast<int>(dshape[2]),
                            static_cast<int>(dshape[3])};
   auto src_md = memory::desc({ src_tz_ }, memory::data_type::f32,
You use f32 here. Will it be a problem when MKL-DNN supports more types?
f64 can also work for MKL-DNN. The problem is that LRNShape can't get the data type information in the symbolic stage. In the short term, I could allocate memory for both fp32 and fp64 here and select one at runtime, but the unused one would be wasted. In the long term, an InferShape function should provide the data type and more information even in the symbolic stage. A possible runtime selection is sketched below.
What's your opinion?
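A minimal sketch of dispatching on the real dtype once it is visible at primitive-creation time (assumes `using namespace mkldnn` as in the surrounding file; GetMKLDNNDataType and the nchw format choice are illustrative assumptions):

```cpp
// Map an MXNet dtype to an MKL-DNN data type instead of hard-coding f32.
static memory::data_type GetMKLDNNDataType(int dtype) {
  switch (dtype) {
    case mshadow::kFloat32:
      return memory::data_type::f32;
    // Further cases can be added here once MKL-DNN exposes more types.
    default:
      LOG(FATAL) << "Unsupported data type for MKL-DNN LRN";
      return memory::data_type::data_undef;
  }
}

// Then, instead of hard-coding f32 in the descriptor:
auto src_md = memory::desc({ src_tz_ }, GetMKLDNNDataType(dtype),
                           memory::format::nchw);
```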
src/operator/nn/lrn.cc
   int n_out = 2;
 #endif
   out_type->clear();
   for (int i = 0; i < n_out; ++i) out_type->push_back(dtype);
Is the dtype for the workspace correct?
I just checked the type of the workspace and it's FP32, but the better way is to query the data type through the MKL-DNN API:
xxx.get_primitive_desc().desc().data.data_type
Will update the code; a sketch of the query is below.
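A small sketch of that query against the MKL-DNN 0.x C++ API (fwd_pd is the LRN forward primitive descriptor from the surrounding code):

```cpp
// Query the workspace data type from the primitive descriptor rather than
// assuming FP32.
auto ws_pd = fwd_pd.workspace_primitive_desc();
auto ws_type = static_cast<memory::data_type>(ws_pd.desc().data.data_type);
CHECK_EQ(ws_type, memory::data_type::f32);  // currently FP32 in practice
```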
   if (this->is_train) {
     if (workspace == nullptr) {
       this->ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
Why do you need to create one when workspace is null? If it's null, shouldn't that mean a workspace isn't required?
No. Two points:
- The workspace is mandatory for LRN training, so we should always have a workspace.
- If it's null, it means we are using this class in a truly stateless way; we don't pass any dependency information between FW and BW or across time steps.
Both paths are sketched below.
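A sketch of the two paths described above (member and variable names are illustrative and follow the diff, not necessarily the final code):

```cpp
if (this->is_train) {
  if (workspace == nullptr) {
    // Stateless path: no workspace NDArray is passed in, so allocate a private
    // one internally; the LRN primitive still needs it to execute.
    this->ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
  } else {
    // Stateful path: bind the caller-provided workspace so backward can read
    // exactly what forward produced.
    this->ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc(),
                                          workspace->data().dptr_));
  }
}
```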
 static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
                                const OpContext &ctx,
                                const NDArray &in_data,
                                const NDArray *workspace) {
indent
Sure.
116b4f0 to b566556 (compare)
Thanks for your comments and good suggestions.
Given the current interface, I think the best solution for the workspace is to use stateful compute. The nnvm interface provides such an option.
@zheng-da I agree with you about the stateful OP.
I will use the new code and re-submit; a rough sketch of the stateful registration is below.
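For reference, a rough sketch of the stateful-compute direction using the nnvm registration attributes. FCreateOpState and FStatefulComputeEx are existing MXNet attribute names, but LRNOpState, CreateLRNState, and LRNStatefulComputeExCPU are hypothetical names for illustration, not the final implementation:

```cpp
// The op state object would own the cached MKL-DNN primitives and the
// workspace memory, so forward can hand the workspace to backward without
// exposing it as an extra output.
static OpStatePtr CreateLRNState(const nnvm::NodeAttrs &attrs,
                                 Context ctx,
                                 const std::vector<TShape> &in_shapes,
                                 const std::vector<int> &in_types) {
  const LRNParam &param = nnvm::get<LRNParam>(attrs.parsed);
  return OpStatePtr::Create<LRNOpState>(param);
}

NNVM_REGISTER_OP(LRN)
.set_attr<FCreateOpState>("FCreateOpState", CreateLRNState)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", LRNStatefulComputeExCPU);
```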
* update tests. * fix shape/dtype/storage inference. * fix.
* Added tutorial for FIT API
* Added tests for Fit API tutorial
* Updated index.md for the new tutorial to show up
* Addressed PR feedback
* Addressed PR feedback
* Removed spurious comment for Py2 and Py3 compatibility
* Address PR feedback
* Addressed PR feedback
* Fixed typo
* Added example to showcase custom event handler
* Fixed imports as estimator moved to contrib package
* Added a side note to inform about estimator reference being updated by the handlers
* Corrected typo
* update tutorial
* address comments
* new line
* fix import
* fix cached graph
* fix import
* address comments
* fix doc gen
* add softmax
* add to website index
* fix doc string
* Fix doc gen (zheng-da#12)
* fix warining
* fix test
* fix
* fix
* fix print
* fix test (zheng-da#13)
* fix warning (zheng-da#14)
* fix href (zheng-da#15)
Description
Refine the code structure of LRN, caching the OP primitive and memory.
Performance improves from 200 img/sec to 230 img/sec for BS=1 on BDW 2699 v4.
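An illustrative sketch of the primitive-caching pattern this describes: a thread-local map keyed by a signature built from the LRN parameters and the input, so the cached MKLDNNLRNFwd (primitive plus workspace) is reused across iterations. MKLDNNLRNSignature, OpHash, and the MKLDNNLRNFwd constructor follow the naming style of the MKL-DNN integration but are assumptions here, not the exact classes in this PR:

```cpp
#include <unordered_map>

static MKLDNNLRNFwd &GetLRNFwd(const LRNParam &param,
                               const OpContext &ctx,
                               const NDArray &in_data,
                               const NDArray *workspace) {
  // One cache per thread; the key captures everything that changes the
  // primitive (parameters, train/inference mode, input shape/memory).
  static thread_local std::unordered_map<MKLDNNLRNSignature, MKLDNNLRNFwd,
                                         OpHash> lrn_fwds;
  MKLDNNLRNSignature key(param);
  key.AddSign(ctx.is_train);
  key.AddSign(in_data);

  auto it = lrn_fwds.find(key);
  if (it == lrn_fwds.end()) {
    // Cache miss: build the forward primitive once and keep it for reuse.
    MKLDNNLRNFwd fwd(param, ctx.is_train, in_data);
    it = lrn_fwds.insert(std::make_pair(key, fwd)).first;
  }
  // The workspace pointer is then bound to the cached fwd object before
  // execution, as in the diff shown earlier in the conversation.
  return it->second;
}
```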
Checklist
Essentials
- make lint
Changes
Comments