-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-483] C++ tests for mkldnn convolution/deconvolution operator #11778
Conversation
0ffb0fe
to
3d33273
Compare
@zheng-da can you review |
can you test both conv and deconv? these two operators have almost the same inputs. |
1cb5052
to
6750e75
Compare
@mseth10 please take a look as well |
tests/cpp/operator/mkldnn.cc
Outdated
backwards_ex_outputs[2] = &tmp_bias2; | ||
|
||
for (int i = 0; i < backwards_attrs.num_outputs; i++) | ||
back_req[0] = kWriteTo; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only WriteTo? I think you should test AddTo as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
regular fcompute for deconvolution/convolution does not support kAddTo currently
https://github.com/azai91/incubator-mxnet/blob/d79e1ad3294837cac653478045023fd312ceed78/src/operator/nn/convolution-inl.h#L178
0bba2a0
to
d1a6fdb
Compare
tests/cpp/operator/mkldnn.cc
Outdated
continue; | ||
shape[dim] = shape[dim] * num_inputs; | ||
|
||
for (int dim = 0; dim < scale.size(); dim++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right type is size_t, ++dim (pre-increment) is standard C++ idiom
tests/cpp/operator/mkldnn.cc
Outdated
|
||
// Type 1. | ||
NDArray arr(shape, Context()); | ||
in_arrs.emplace_back(arr, "Normal NDArray"); | ||
InitDefaultArray(&in_arrs.back().arr, rand); | ||
for (auto pd : pds) { | ||
if (num_inputs > 1) { | ||
for (int dim = 0; dim < scale.size(); dim++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
size_t, otherwise signed vs unsigned comparison warning
56f575d
to
6cfd95f
Compare
@TaoLv this PR already got many reviews and one approval. Is it okay to merge this PR, or you have further unaddressed comments? |
@eric-haibin-lin, need @azai91 's inputs about the difference of this PR and #13084. |
#13084 is a branch off of this PR that includes the BN unit tests. we should merge this one first and then we can more easily review #13084. @TaoLv @eric-haibin-lin |
tests/cpp/include/test_mkldnn.h
Outdated
|
||
inline NDArray CreateKernelNDArray(TShape kernel, int num_filters, TShape input, | ||
bool is_deconv = false) { | ||
CHECK(kernel.ndim() == 2) << "mkldnn only supports 2d filters on 4d inputs"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CHECK_EQ
|
@TaoLv I checked that this PR does not cause regression. |
what MKL-DNN warnings? are they introduced in this PR or were they always present? if they were already present, then let's file an issue and tackle in a separate PR. |
Thanks, @azai91 . I noticed that there were many warnings when built mkldnn cpp test with Makefile. It would be great if you can take a look and fix them. Not necessarily in this PR. |
@TaoLv and @pengzhao-intel if you are good with this PR, please approve. |
} else { | ||
weight_mem = weight.GetMKLDNNData(); | ||
CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc()); | ||
} | ||
weight_mem = weight.GetMKLDNNData(); | ||
CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need else if these two lines are going to execute irrespective
The warnings are related to batch norm test and implicit_gemm in mshadow. Since they are not being modified in this PR, I think its fine if they are not fixed in this PR. |
@azai91 @anirudh2290 thanks for the help and great works to improve the quality of MKL-DNN. |
Description
Add tests for MKLDNN convolution / deconvolution operator. Need to convert MKLDNN conv/deconv inputs if they are views (#12303).
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments