Add reshape op supported by MKL-DNN #12980
@@ -103,6 +103,57 @@ DMLC_REGISTER_PARAMETER(StackParam);
 DMLC_REGISTER_PARAMETER(SqueezeParam);
 DMLC_REGISTER_PARAMETER(DepthToSpaceParam);

+#if MXNET_USE_MKLDNN == 1
+void MKLDNNReshape(const NDArray &in_data, const NDArray &out_data) {
+  MSHADOW_TYPE_SWITCH(in_data.dtype(), DType, {
+    auto this_mem = in_data.GetMKLDNNData();
+    auto out_dptr = out_data.data().dptr<DType>();
+    mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
+    mkldnn::memory::desc this_desc = this_pd.desc();
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, this_pd.get_engine());
+    auto temp_mem = mkldnn::memory(pd, out_dptr);
+    MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*this_mem, temp_mem));
+    MKLDNNStream::Get()->Submit();
+
+    // Remove the mkl_mem_ from out_data and store the data in the default format.
+    const_cast<NDArray &>(out_data).InvalidateMKLDNNData();
+  });
+}
+
+static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  // If the input may carry an MKL-DNN layout and MKL-DNN supports the
+  // data type and shape, reorder it into the output's default format and
+  // shape; otherwise fall back to the identity compute.
+  if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape()) && req[0] != kAddTo) {
+    MKLDNNReshape(inputs[0], outputs[0]);
+    return;
+  }
+  FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
+                                      const int dev_mask,
+                                      DispatchMode* dispatch_mode,
+                                      std::vector<int>* in_attrs,
+                                      std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(Reshape)
 .add_alias("reshape")
 .describe(R"code(Reshapes the input array.
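A note on the mechanism before the registration hunks below: MKLDNNReshape reorders the input's (possibly blocked) MKL-DNN layout into the default plain layout, writing directly into the output buffer, after which the bytes can be viewed with any shape of the same size. Here is a minimal standalone sketch of that pattern, assuming the MKL-DNN 0.x C++ API used in this file; all names, sizes, and formats below are illustrative, not taken from the PR.

#include <vector>
#include <mkldnn.hpp>

int main() {
  mkldnn::engine eng(mkldnn::engine::cpu, 0);
  std::vector<float> src_buf(2 * 8 * 4 * 4), dst_buf(src_buf.size());

  // Source in a blocked layout (nChw8c), as MKL-DNN kernels often produce.
  mkldnn::memory::desc src_md({2, 8, 4, 4}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format::nChw8c);
  mkldnn::memory src_mem({src_md, eng}, src_buf.data());

  // Destination in the default plain layout (nchw) for rank-4 tensors.
  mkldnn::memory::desc dst_md({2, 8, 4, 4}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format::nchw);
  mkldnn::memory dst_mem({dst_md, eng}, dst_buf.data());

  // Submit the reorder; MXNet wraps this step in MKLDNNStream instead.
  std::vector<mkldnn::primitive> net;
  net.push_back(mkldnn::reorder(src_mem, dst_mem));
  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
  return 0;
}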
@@ -174,6 +225,14 @@ If the argument `reverse` is set to 1, then the special values are inferred from
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
+.set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#else
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                 [](const NodeAttrs& attrs) {
                                   return std::vector<std::pair<int, int> >{{0, 0}};
@@ -182,6 +241,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
                                 [](const NodeAttrs& attrs){
                                   return std::vector<bool>{true};
                                 })
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Input data to reshape.")
 .add_arguments(ReshapeParam::__FIELDS__());
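The MKL-DNN branch registered above is only taken when SupportMKLDNNArray (checked in ReshapeComputeExCPU) approves the array's dtype and shape. A rough, illustrative version of that kind of check, assumed rather than quoted from the MXNet helper:

// Illustrative only: the MKL-DNN integration here generally restricts
// itself to 1-, 2-, and 4-dimensional arrays of a few element types.
static bool SupportMKLDNNArraySketch(int dtype, const TShape &shape) {
  int ndim = shape.ndim();
  bool support = (ndim == 1 || ndim == 2 || ndim == 4);
  support = support && (dtype == mshadow::kFloat32 || dtype == mshadow::kInt32 ||
                        dtype == mshadow::kInt8 || dtype == mshadow::kUint8);
  return support;
}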
@@ -210,24 +270,18 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
 #endif
 }

+#if MXNET_USE_MKLDNN == 1
Review comment: Seems this will change the original behavior when MKL-DNN is not enabled.

Reply: This follows the new style used for the other MKL-DNN ops, i.e. defining the InferStorageType function within the MKLDNN macro. Most of the other ops were refactored into this style by Luobao; I guess this one was missed.
 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
                                       const int dev_mask,
                                       DispatchMode* dispatch_mode,
                                       std::vector<int> *in_attrs,
                                       std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
-  bool ret = ElemwiseStorageType<1, 1, false, false, false>(attrs, dev_mask, dispatch_mode,
-                                                            in_attrs, out_attrs);
-#if MXNET_USE_MKLDNN == 1
-  if (dev_mask == mshadow::cpu::kDevMask
-      && in_attrs->at(0) == kDefaultStorage
-      && out_attrs->at(0) == kDefaultStorage) {
-    *dispatch_mode = DispatchMode::kFComputeEx;
-  }
-#endif
-  return ret;
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
 }
+#endif

 NNVM_REGISTER_OP(Flatten)
 .add_alias("flatten")
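The refactor above replaces the hand-rolled dispatch logic with the shared MKLDNNStorageType helper. A simplified sketch of what such a helper is expected to decide, inferred from the removed inline code rather than quoted from mkldnn_base.cc:

// Illustrative only: choose dense storage everywhere and route CPU
// execution to FComputeEx when MKL-DNN can handle the op, mirroring the
// inline logic this refactor removed.
static bool MKLDNNStorageTypeSketch(const nnvm::NodeAttrs& attrs,
                                    const int dev_mask,
                                    bool support_mkldnn,
                                    DispatchMode* dispatch_mode,
                                    std::vector<int>* in_attrs,
                                    std::vector<int>* out_attrs) {
  for (int& stype : *in_attrs) stype = kDefaultStorage;
  for (int& stype : *out_attrs) stype = kDefaultStorage;
  *dispatch_mode = (dev_mask == mshadow::cpu::kDevMask && support_mkldnn)
                       ? DispatchMode::kFComputeEx
                       : DispatchMode::kFCompute;
  return true;
}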
@@ -261,7 +315,9 @@ Example::
 .set_num_outputs(1)
 .set_attr<nnvm::FInferShape>("FInferShape", FlattenShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+#if MXNET_USE_MKLDNN == 1
Review comment: If MKL-DNN supports this op, please add the `TIsMKLDNN` attribute.

Reply: This attribute is in Line 327.
 .set_attr<FInferStorageType>("FInferStorageType", FlattenStorageType)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_copy" })
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", FlattenEx)
Review comment: I know that this is not the first op implemented this way, but is there a reason for the two different `#if` blocks here?

Reply: Nothing different; they have been merged into one block here. Thanks for your review! :)
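One closing note on the mechanism used by MKLDNNReshape earlier in this diff: the reorder targets the default plain layout for the tensor's rank, obtained via GetDefaultFormat. A simplified sketch of that rank-to-format mapping, assumed for illustration (the real helper also handles other ranks and weight layouts):

// Illustrative only: map a tensor rank to the plain row-major MKL-DNN 0.x
// format that serves as the "default" layout targeted by the reorder.
static mkldnn::memory::format DefaultFormatSketch(int ndims) {
  switch (ndims) {
    case 1: return mkldnn::memory::format::x;
    case 2: return mkldnn::memory::format::nc;
    case 4: return mkldnn::memory::format::nchw;
    default: return mkldnn::memory::format::format_undef;
  }
}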