This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Add oneDNN support for npx.reshape and np.reshape #20563

Merged: 2 commits, Sep 17, 2021

src/operator/nn/mkldnn/mkldnn_base-inl.h (3 changes: 2 additions & 1 deletion)
@@ -156,7 +156,7 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape& shape) {
          (ndim == 1 || ndim == 2 || ndim == 4);
 }

-static inline bool SupportMKLDNNQuantize(int dtype) {
+static inline bool IsMKLDNNType(int dtype) {
   return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 || dtype == mshadow::kUint8 ||
          dtype == mshadow::kBfloat16;
 }
@@ -218,6 +218,7 @@ bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam& param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray& data);
 bool SupportMKLDNNBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
 bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
+bool SupportMKLDNNReshape(const NDArray& input, const NDArray& output);
 }  // namespace op

 static int GetTypeSize(int dtype) {
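
The rename from SupportMKLDNNQuantize to IsMKLDNNType fits the predicate's actual role: it is a general "does oneDNN handle this dtype" check rather than anything quantization-specific, and this PR now reuses it from the new reshape path. A minimal standalone sketch of its behavior, using stand-in dtype codes rather than mshadow's real TypeFlag values (which this page does not show):

    #include <cassert>

    // Stand-in dtype codes; the real values come from mshadow's TypeFlag enum.
    enum DType { kFloat32, kFloat64, kFloat16, kUint8, kInt8, kInt32, kInt64, kBfloat16 };

    // Mirrors the renamed predicate: the oneDNN paths here accept fp32, int8,
    // uint8 and bfloat16 tensors; everything else takes the native CPU path.
    static bool IsMKLDNNType(int dtype) {
      return dtype == kFloat32 || dtype == kInt8 || dtype == kUint8 || dtype == kBfloat16;
    }

    int main() {
      assert(IsMKLDNNType(kFloat32));
      assert(IsMKLDNNType(kBfloat16));
      assert(!IsMKLDNNType(kFloat64));  // fp64 is not on the oneDNN path here
      assert(!IsMKLDNNType(kInt32));
      return 0;
    }
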
src/operator/nn/mkldnn/mkldnn_convolution.cc (7 changes: 3 additions & 4 deletions)
@@ -37,11 +37,10 @@ namespace op {
 DMLC_REGISTER_PARAMETER(MKLDNNConvParam);

 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray& input) {
-  if ((params.kernel.ndim() != 1) && (params.kernel.ndim() != 2) && (params.kernel.ndim() != 3))
+  if (params.kernel.ndim() > 3 || params.kernel.ndim() == 0)
     return false;
-  return SupportMKLDNNQuantize(input.dtype()) &&
-         ((input.shape().ndim() == 3) || (input.shape().ndim() == 4) ||
-          (input.shape().ndim() == 5));
+  return IsMKLDNNType(input.dtype()) &&
+         input.shape().ndim() >= 3 && input.shape().ndim() <= 5;
 }

 std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
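
The rewritten checks are pure simplifications. For a non-negative kernel rank n, rejecting when n is not 1, 2 or 3 is the same as rejecting when n > 3 or n == 0, and the three equality tests on the input rank collapse into one range test. A throwaway verification snippet for that equivalence (my own check, not part of the PR):

    #include <cassert>

    int main() {
      for (int n = 0; n <= 8; ++n) {  // ndim() is never negative here
        const bool old_reject = (n != 1) && (n != 2) && (n != 3);
        const bool new_reject = (n > 3) || (n == 0);
        assert(old_reject == new_reject);  // identical for every non-negative rank

        const bool old_accept = (n == 3) || (n == 4) || (n == 5);
        const bool new_accept = (n >= 3) && (n <= 5);
        assert(old_accept == new_accept);
      }
      return 0;
    }
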
src/operator/nn/mkldnn/mkldnn_reshape.cc (16 changes: 7 additions & 9 deletions)
@@ -33,6 +33,13 @@
 namespace mxnet {
 namespace op {

+bool SupportMKLDNNReshape(const NDArray& input, const NDArray& output) {
+  const int input_ndims = input.shape().ndim();
+  const int output_ndims = output.shape().ndim();
+  return input.shape().Size() > 0 && input_ndims >= 1 && input_ndims <= 6 && output_ndims >= 1 &&
+         output_ndims <= 6 && IsMKLDNNType(input.dtype());
+}
+
 MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType& req,
                                    const NDArray& input,
                                    const NDArray& output) {
@@ -121,15 +128,6 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                           const NDArray& input,
                           const OpReqType& req,
                           const NDArray& output) {
-  // For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's simply fallback to
-  // naive implement.
-  const int input_ndims = input.shape().ndim();
-  if ((input_ndims < 1 || input_ndims > 4) || !SupportMKLDNNQuantize(input.dtype())) {
-    if (req != kWriteInplace) {
-      FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, {input}, {req}, {output});
-    }
-    return;
-  }
   if (req == kNullOp)
     return;
   CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
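
The eligibility check moves out of MKLDNNReshapeForward into the reusable SupportMKLDNNReshape predicate, widening from 1-4 to 1-6 dimensions and adding a guard against zero-size tensors. One visible behavioral difference: the removed in-function check skipped the fallback copy for kWriteInplace requests, whereas the new call sites run FallBackCompute whenever the predicate fails. A self-contained model of the predicate over a plain shape vector (illustrative types only; the real code takes mxnet::NDArray):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Models SupportMKLDNNReshape: non-empty input, both ranks in [1, 6],
    // and a oneDNN-supported dtype (reduced to a bool in this sketch).
    static bool SupportReshapeModel(const std::vector<std::size_t>& in_shape,
                                    const std::vector<std::size_t>& out_shape,
                                    bool is_mkldnn_dtype) {
      std::size_t size = 1;
      for (std::size_t d : in_shape) size *= d;  // Size(): total element count
      const std::size_t in_ndims = in_shape.size();
      const std::size_t out_ndims = out_shape.size();
      return size > 0 && in_ndims >= 1 && in_ndims <= 6 &&
             out_ndims >= 1 && out_ndims <= 6 && is_mkldnn_dtype;
    }

    int main() {
      assert(SupportReshapeModel({2, 3, 4}, {6, 4}, true));            // in range
      assert(!SupportReshapeModel({2, 0, 4}, {8}, true));              // empty tensor
      assert(!SupportReshapeModel({1, 1, 1, 1, 1, 1, 1}, {1}, true));  // 7-D input
      assert(!SupportReshapeModel({2, 3}, {6}, false));                // unsupported dtype
      return 0;
    }
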
src/operator/numpy/np_matrix_op.cc (8 changes: 8 additions & 0 deletions)
@@ -380,6 +380,14 @@ NNVM_REGISTER_OP(_npx_reshape)
     .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
     .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<bool>("TIsMKLDNN", true)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
+    .set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
+    .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+    })
+#endif
     .set_attr<nnvm::FInplaceOption>("FInplaceOption",
                                     [](const NodeAttrs& attrs) {
                                       return std::vector<std::pair<int, int> >{{0, 0}};
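
These registrations are what route _npx_reshape through oneDNN: FInferStorageType lets the storage-type pass select the MKLDNN dispatch path (ReshapeStorageType delegates to MKLDNNStorageType, shown further down in matrix_op.cc), FComputeEx is the NDArray-based entry point used when that dispatch mode is chosen, and the resource request reserves scratch space. A toy model of attribute registration and entry-point preference (illustrative only; NNVM's real attribute machinery is typed and more involved):

    #include <cassert>
    #include <functional>
    #include <map>
    #include <string>

    // An op is a bag of named attributes; set_attr chains like NNVM_REGISTER_OP.
    struct Op {
      std::map<std::string, std::function<int()>> attrs;
      Op& set_attr(const std::string& name, std::function<int()> fn) {
        attrs[name] = std::move(fn);
        return *this;
      }
    };

    int main() {
      Op reshape;
      reshape.set_attr("FCompute<cpu>", [] { return 1; })     // identity copy
             .set_attr("FComputeEx<cpu>", [] { return 2; });  // storage-aware path
      // When dispatch selects the richer entry point, it is used instead.
      auto it = reshape.attrs.find("FComputeEx<cpu>");
      const int chosen = it != reshape.attrs.end() ? it->second()
                                                   : reshape.attrs["FCompute<cpu>"]();
      assert(chosen == 2);  // with both registered, FComputeEx is taken
      return 0;
    }
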
src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc (6 changes: 5 additions & 1 deletion)
@@ -45,7 +45,11 @@ static void MKLDNNQuantizedFlattenForward(const nnvm::NodeAttrs& attrs,
                                           const std::vector<NDArray>& inputs,
                                           const std::vector<OpReqType>& req,
                                           const std::vector<NDArray>& outputs) {
-  MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+  if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
+    MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+  } else {
+    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+  }
   outputs[1].data().dptr<float>()[0] = inputs[1].data().dptr<float>()[0];
   outputs[2].data().dptr<float>()[0] = inputs[2].data().dptr<float>()[0];
 }
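
The quantized flatten gains the same gate-or-fallback structure, and either way the two calibration scalars (the min/max ranges in inputs[1] and inputs[2]) are copied through untouched. A toy model of that contract with plain floats (names are illustrative, not MXNet's):

    #include <cassert>
    #include <vector>

    struct QuantizedBlob {
      std::vector<float> data;  // the tensor payload
      float min_range;          // calibration minimum (inputs[1] / outputs[1])
      float max_range;          // calibration maximum (inputs[2] / outputs[2])
    };

    static QuantizedBlob FlattenModel(const QuantizedBlob& in) {
      QuantizedBlob out;
      out.data = in.data;            // reshape changes metadata, not bytes,
                                     // whichever of the two paths runs
      out.min_range = in.min_range;  // outputs[1] <- inputs[1]
      out.max_range = in.max_range;  // outputs[2] <- inputs[2]
      return out;
    }

    int main() {
      const QuantizedBlob in{{1.f, 2.f, 3.f, 4.f}, -1.f, 1.f};
      const QuantizedBlob out = FlattenModel(in);
      assert(out.data.size() == 4);
      assert(out.min_range == -1.f && out.max_range == 1.f);
      return 0;
    }
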
src/operator/tensor/matrix_op-inl.h (13 changes: 13 additions & 0 deletions)
@@ -80,6 +80,19 @@ struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
   }
 };

+#if MXNET_USE_ONEDNN == 1
+bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
+                        const int dev_mask,
+                        DispatchMode* dispatch_mode,
+                        std::vector<int>* in_attrs,
+                        std::vector<int>* out_attrs);
+void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs);
+#endif  // MXNET_USE_ONEDNN == 1
+
 template <typename IType>
 inline mxnet::TShape InferReshapeShape(const mxnet::Tuple<IType>& shape,
                                        const mxnet::TShape& dshape,

src/operator/tensor/matrix_op.cc (42 changes: 28 additions & 14 deletions)
@@ -26,8 +26,9 @@
 #include "./matrix_op-inl.h"
 #include "./elemwise_unary_op.h"
 #if MXNET_USE_ONEDNN == 1
-#include "../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../nn/mkldnn/mkldnn_base-inl.h"
+#include "../nn/mkldnn/mkldnn_ops-inl.h"
+#include "../nn/mkldnn/mkldnn_reshape-inl.h"
 #include "../nn/mkldnn/mkldnn_slice-inl.h"
 #endif

@@ -114,24 +115,29 @@ DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
 DMLC_REGISTER_PARAMETER(SplitParam);

 #if MXNET_USE_ONEDNN == 1
-static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
-                                const OpContext& ctx,
-                                const std::vector<NDArray>& inputs,
-                                const std::vector<OpReqType>& req,
-                                const std::vector<NDArray>& outputs) {
+void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   // If inputs are supposed to be in MKLDNN format and
   // MKLDNN support the data type or the shape. Then convert
   // it to the output format and shape
-  MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+
+  if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
+    MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+  } else {
+    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+  }
 }

-inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
-                                      const int dev_mask,
-                                      DispatchMode* dispatch_mode,
-                                      std::vector<int>* in_attrs,
-                                      std::vector<int>* out_attrs) {
+bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
+                        const int dev_mask,
+                        DispatchMode* dispatch_mode,
+                        std::vector<int>* in_attrs,
+                        std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
@@ -228,7 +234,11 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
   // If inputs are supposed to be in MKLDNN format and
   // MKLDNN support the data type or the shape. Then convert
   // it to the output format and shape
-  MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+  if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
+    MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+  } else {
+    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+  }
 }

 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
@@ -394,7 +404,11 @@ static void ExpandDimEx(const nnvm::NodeAttrs& attrs,
   // If inputs are supposed to be in MKLDNN format and
   // MKLDNN support the data type or the shape. Then convert
   // it to the output format and shape
-  MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+  if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
+    MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
+  } else {
+    FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+  }
 }

 inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs,
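
ReshapeComputeExCPU, FlattenEx and ExpandDimEx now share one pattern: check SupportMKLDNNReshape, run the oneDNN primitive when it passes, otherwise fall back to the identity copy. The reshape pair also loses its static/inline linkage so that np_matrix_op.cc can call it through the declarations added to matrix_op-inl.h. A compressed sketch of that shared control flow, with the two paths reduced to callbacks (illustrative, not the PR's code):

    #include <cassert>
    #include <functional>

    // Gate-or-fallback: the branch repeated at all three call sites.
    static int ReshapeLikeOp(bool supported,
                             const std::function<int()>& mkldnn_path,
                             const std::function<int()>& fallback_path) {
      if (supported)
        return mkldnn_path();  // MKLDNNRun(MKLDNNReshapeForward, ...)
      return fallback_path();  // FallBackCompute(UnaryOp::IdentityCompute<cpu>, ...)
    }

    int main() {
      const auto fast = [] { return 1; };
      const auto slow = [] { return 2; };
      assert(ReshapeLikeOp(true, fast, slow) == 1);   // eligible: oneDNN reshape
      assert(ReshapeLikeOp(false, fast, slow) == 2);  // otherwise: identity copy
      return 0;
    }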