Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[FEATURE] Add oneDNN support for npx.reshape and np.reshape (#20563)
Browse files Browse the repository at this point in the history
* Add oneDNN support for npx.reshape and np.reshape

* Fix SupportMKLDNN function for Convolution and Reshape
  • Loading branch information
agrabow committed Sep 17, 2021
1 parent 72e6b16 commit 8832c42
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 29 deletions.
3 changes: 2 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape& shape) {
(ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNNQuantize(int dtype) {
static inline bool IsMKLDNNType(int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 || dtype == mshadow::kUint8 ||
dtype == mshadow::kBfloat16;
}
Expand Down Expand Up @@ -218,6 +218,7 @@ bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam& param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray& data);
bool SupportMKLDNNBatchDot(const std::vector<NDArray>& inputs, const NDArray& output);
bool SupportMKLDNNLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
bool SupportMKLDNNReshape(const NDArray& input, const NDArray& output);
} // namespace op

static int GetTypeSize(int dtype) {
Expand Down
7 changes: 3 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ namespace op {
DMLC_REGISTER_PARAMETER(MKLDNNConvParam);

bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray& input) {
if ((params.kernel.ndim() != 1) && (params.kernel.ndim() != 2) && (params.kernel.ndim() != 3))
if (params.kernel.ndim() > 3 || params.kernel.ndim() == 0)
return false;
return SupportMKLDNNQuantize(input.dtype()) &&
((input.shape().ndim() == 3) || (input.shape().ndim() == 4) ||
(input.shape().ndim() == 5));
return IsMKLDNNType(input.dtype()) &&
input.shape().ndim() >= 3 && input.shape().ndim() <= 5;
}

std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
Expand Down
16 changes: 7 additions & 9 deletions src/operator/nn/mkldnn/mkldnn_reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@
namespace mxnet {
namespace op {

bool SupportMKLDNNReshape(const NDArray& input, const NDArray& output) {
const int input_ndims = input.shape().ndim();
const int output_ndims = output.shape().ndim();
return input.shape().Size() > 0 && input_ndims >= 1 && input_ndims <= 6 && output_ndims >= 1 &&
output_ndims <= 6 && IsMKLDNNType(input.dtype());
}

MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType& req,
const NDArray& input,
const NDArray& output) {
Expand Down Expand Up @@ -121,15 +128,6 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const NDArray& input,
const OpReqType& req,
const NDArray& output) {
// For mkldnn non-supported input, it shouldn't hold mkldnn memory, so let's simply fallback to
// naive implement.
const int input_ndims = input.shape().ndim();
if ((input_ndims < 1 || input_ndims > 4) || !SupportMKLDNNQuantize(input.dtype())) {
if (req != kWriteInplace) {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, {input}, {req}, {output});
}
return;
}
if (req == kNullOp)
return;
CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
Expand Down
8 changes: 8 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,14 @@ NNVM_REGISTER_OP(_npx_reshape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
Expand Down
6 changes: 5 additions & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantized_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ static void MKLDNNQuantizedFlattenForward(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
outputs[1].data().dptr<float>()[0] = inputs[1].data().dptr<float>()[0];
outputs[2].data().dptr<float>()[0] = inputs[2].data().dptr<float>()[0];
}
Expand Down
13 changes: 13 additions & 0 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
}
};

#if MXNET_USE_ONEDNN == 1
bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs);
void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);
#endif // MXNET_USE_ONEDNN == 1

template <typename IType>
inline mxnet::TShape InferReshapeShape(const mxnet::Tuple<IType>& shape,
const mxnet::TShape& dshape,
Expand Down
42 changes: 28 additions & 14 deletions src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
#include "./matrix_op-inl.h"
#include "./elemwise_unary_op.h"
#if MXNET_USE_ONEDNN == 1
#include "../nn/mkldnn/mkldnn_ops-inl.h"
#include "../nn/mkldnn/mkldnn_base-inl.h"
#include "../nn/mkldnn/mkldnn_ops-inl.h"
#include "../nn/mkldnn/mkldnn_reshape-inl.h"
#include "../nn/mkldnn/mkldnn_slice-inl.h"
#endif

Expand Down Expand Up @@ -114,24 +115,29 @@ DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
DMLC_REGISTER_PARAMETER(SplitParam);

#if MXNET_USE_ONEDNN == 1
static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
// If inputs are supposed to be in MKLDNN format and
// MKLDNN support the data type or the shape. Then convert
// it to the output format and shape
MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);

if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
}

inline static bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
bool ReshapeStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
Expand Down Expand Up @@ -228,7 +234,11 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
// If inputs are supposed to be in MKLDNN format and
// MKLDNN support the data type or the shape. Then convert
// it to the output format and shape
MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
}

static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
Expand Down Expand Up @@ -394,7 +404,11 @@ static void ExpandDimEx(const nnvm::NodeAttrs& attrs,
// If inputs are supposed to be in MKLDNN format and
// MKLDNN support the data type or the shape. Then convert
// it to the output format and shape
MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
MKLDNNRun(MKLDNNReshapeForward, attrs, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
}

inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs,
Expand Down

0 comments on commit 8832c42

Please sign in to comment.