Conversation
ce19fd5 to 436ffa4 (compare)
src/operator/tensor/matrix_op.cc (outdated)
// is larger than 2, we should use the default layout.
if (outputs[0].IsMKLDNNData() && inputs[0].shape().ndim() > 2)
  const_cast<NDArray &>(outputs[0]).Reorder2Default();
if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape())) {
SupportMKLDNNArray doesn't support 3D tensors; flatten should have the same coverage as reshape, right?
Yes, you're right.
Use the same conditions in SupportMKLDNNReshape.
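
To make the suggestion concrete, a minimal sketch of the dispatch is below. This is hypothetical: the function and helper names (FlattenComputeExCPU, FallBackCompute) and the exact SupportMKLDNNReshape signature are illustrative and assume MXNet's internal operator headers, not the code in this PR.

static void FlattenComputeExCPU(const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
                                const std::vector<NDArray> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
  // Gate the MKLDNN path on the same check reshape uses, so flatten and
  // reshape cover exactly the same dtypes and shapes.
  if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
    // Flatten is backed by the MKLDNN reshape kernel.
    MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
    return;
  }
  // Otherwise fall back to the default CPU implementation.
  FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
}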
@@ -98,62 +75,63 @@ class MKLDNNReshapeForward {
  } else {
    LOG(FATAL) << "not supported req type: " << req;
  }
Indent from line 38 to 77?
@@ -119,12 +119,11 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                            const OpReqType &req,
                            const NDArray &output);

-void MKLDNNReshapeForward(const nnvm::NodeAttrs &attrs,
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
Better to keep both the flatten and reshape function declarations here.
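
In other words, the header would keep both entry points, roughly like the following sketch (parameter lists are taken from the surrounding diff and may not match the exact signatures):

// Keep both declarations visible to callers of the MKLDNN reshape code.
void MKLDNNReshapeForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                          const NDArray &input, const OpReqType &req,
                          const NDArray &output);
void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                          const NDArray &input, const OpReqType &req,
                          const NDArray &output);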
    : MKLDNNReshapeFwd(req, input, output) {}
};

static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req,
Is it possible to combine GetFlattenForward and GetReshapeForward into one, and call it by passing different template parameters? Then we could still reuse most of the function when implementing other ops like expand_dims.
It seems these two functions cannot be combined into one. The reshape op has a parameter, ReshapeParam, while the flatten op doesn't, so when we create the key, reshape uses MKLDNNReshapeSignature key(ReshapeParam) but flatten uses a plain OpSignature key. So this function has to be designed differently.
Also, the expand_dims op does have a parameter, so it can reuse this function with the reshape op.
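
For illustration only, the cache lookup that both getters share could in principle be factored into one templated helper, with each op building its own key before calling it. The helper name GetCachedForward is made up; the sketch assumes MXNet's OpSignature/OpHash helpers and the usual thread-local cache pattern, and is not code from this PR.

#include <unordered_map>

// Hypothetical shared cache: the signature (key) type is a template parameter,
// so reshape/expand_dims can pass a param-based signature while flatten passes
// a plain OpSignature. Key construction still differs per op and stays in the
// caller, which is the asymmetry described above.
template <typename FwdT, typename SignatureT>
static FwdT &GetCachedForward(const SignatureT &key,
                              const OpReqType &req,
                              const NDArray &input,
                              const NDArray &output) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<SignatureT, FwdT, OpHash> fwds;
#else
  static MX_THREAD_LOCAL std::unordered_map<SignatureT, FwdT, OpHash> fwds;
#endif
  auto it = fwds.find(key);
  if (it == fwds.end()) {
    FwdT fwd(req, input, output);
    it = fwds.emplace(key, fwd).first;
  }
  return it->second;
}

A flatten caller would then build an OpSignature from req, input, and output and call GetCachedForward<MKLDNNFlattenFwd>(key, req, input, output), while reshape would first build its MKLDNNReshapeSignature from ReshapeParam.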
@mxnet-label-bot add [pr-awaiting-review]
667dd37 to 2c3472f (compare)
@arcadiaphy, it would be highly appreciated if you could help verify this fix with the Java demo case. Hopefully this PR fixes the issue in #15267.
@TaoLv I've tested the Java demo; the problem is solved. Thanks!
@pengzhao-intel @TaoLv @ciyongch CI has passed. Please take another look. Thanks.
Thanks for the improvements. Please also add the op to the MKLDNN supported operator list: https://mxnet.incubator.apache.org/versions/master/tutorials/mkldnn/operator_list.html
@pengzhao-intel @TaoLv Thanks for your advice. Updated.
Thanks for your contribution. Merging now.
* Fix flatten issue before slice op
* fix cpplint
* address comments
* retrigger CI
* trigger CI
* retrigger CI
* use SupportMKLDNNReshape and update operator list
Description
This PR should fix issue #15267. The previous FP32 flatten op did not work properly in some situations, so we reimplement it using the MKLDNN reshape op.
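
As a small self-contained illustration of why reusing the reshape kernel is sufficient (explanatory only, not code from this PR): flatten keeps the first axis and collapses the remaining axes into one, which is exactly a reshape to a 2-D shape.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Flatten produces (shape[0], prod(shape[1..n-1])); assumes a non-empty
// input shape.
std::vector<int64_t> FlattenedShape(const std::vector<int64_t> &shape) {
  int64_t rest = std::accumulate(shape.begin() + 1, shape.end(),
                                 static_cast<int64_t>(1),
                                 std::multiplies<int64_t>());
  return {shape[0], rest};
}

// Example: FlattenedShape({8, 3, 224, 224}) yields {8, 150528}.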
@pengzhao-intel @ciyongch @TaoLv please help review. Thanks
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments