From 7a21ebca8bbe17fde49c3b1ca3f31b835a33afb8 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Fri, 15 Dec 2017 19:35:57 +0000 Subject: [PATCH] Use MKLDNN sum in more cases. --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 7 +++++- .../tensor/elemwise_binary_op_basic.cc | 6 ++--- src/operator/tensor/elemwise_sum.cc | 25 ++++++++++++++++--- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 48d25022231d..f14030973736 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -122,13 +122,18 @@ static inline bool SupportMKLDNNArray(int dtype, const TShape &shape) { return support; } +static inline bool SupportStorageMKLDNN(int stype) { /* true only for kMKLDNNStorage or kDefaultStorage */ + return stype == kMKLDNNStorage || stype == kDefaultStorage; +} + static inline bool SupportMKLDNN(int dtype, const TShape &shape) { int ndim = shape.ndim(); return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4); } static inline bool SupportMKLDNN(const NDArray &input) { - return SupportMKLDNN(input.dtype(), input.shape()); + return SupportMKLDNN(input.dtype(), input.shape()) /* dtype and shape check */ + && SupportStorageMKLDNN(input.storage_type()); /* plus the storage-type check added above */ } static inline bool SupportMKLDNNConv(const NDArray &input) { diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index 1c5ff0ec91d5..ae143684a1d8 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -25,6 +25,7 @@ #include "./elemwise_unary_op.h" #include "./elemwise_binary_op-inl.h" #include "../nn/mkldnn/mkldnn_ops-inl.h" +#include "../nn/mkldnn/mkldnn_base-inl.h" namespace mxnet { namespace op { @@ -37,8 +38,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); #if MXNET_USE_MKLDNN == 1 - if (inputs[0].storage_type() == kMKLDNNStorage - || inputs[1].storage_type() 
== kMKLDNNStorage) { + if (SupportMKLDNN(inputs[0]) && SupportMKLDNN(inputs[1])) { /* MKLDNN sum only when both inputs pass SupportMKLDNN */ MKLDNNSum_Forward(attrs, ctx, inputs, req[0], outputs[0]); return; } else if (inputs[0].storage_type() == kDefaultStorage @@ -68,7 +68,7 @@ static inline bool ElemwiseAddStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); #if MXNET_USE_MKLDNN == 1 - if ((in_attrs->at(0) == kMKLDNNStorage || in_attrs->at(1) == kMKLDNNStorage) + if ((SupportStorageMKLDNN(in_attrs->at(0)) && SupportStorageMKLDNN(in_attrs->at(1))) /* both inputs must qualify, mirroring the check in ElemwiseAddEx */ && dev_mask == mshadow::cpu::kDevMask) { out_attrs->at(0) = kMKLDNNStorage; *dispatch_mode = DispatchMode::kFComputeEx; diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc index ed12917594e2..1ee2c9a4235c 100644 --- a/src/operator/tensor/elemwise_sum.cc +++ b/src/operator/tensor/elemwise_sum.cc @@ -25,6 +25,7 @@ #include "./elemwise_sum.h" #include "../../ndarray/ndarray_function.h" #include "../nn/mkldnn/mkldnn_ops-inl.h" +#include "../nn/mkldnn/mkldnn_base-inl.h" #include "../../common/utils.h" namespace mxnet { @@ -74,6 +75,25 @@ bool ElementWiseSumType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +#if MXNET_USE_MKLDNN == 1 /* vector overloads: true only when every element passes the per-element check */ +static inline bool SupportMKLDNN(const std::vector<NDArray>& inputs) { + for (auto &i : inputs) { + if (!SupportMKLDNN(i)) + return false; + } + return true; +} + +static inline bool SupportStorageMKLDNN(const std::vector<int> &inputs) { + for (int i : inputs) { + if (!mxnet::SupportStorageMKLDNN(i)) /* qualified: this overload hides the int version */ + return false; + } + return true; +} + +#endif + bool ElementWiseSumForwardInferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, @@ -82,8 +102,7 @@ bool ElementWiseSumForwardInferStorageType(const nnvm::NodeAttrs& attrs, CHECK(!in_attrs->empty()); CHECK_EQ(out_attrs->size(), 1U); #if MXNET_USE_MKLDNN == 1 - if (dev_mask == mshadow::cpu::kDevMask - && common::ContainsStorage(*in_attrs, kMKLDNNStorage)) { + if (dev_mask == 
mshadow::cpu::kDevMask && SupportStorageMKLDNN(*in_attrs)) { /* CPU device and every input storage type passes SupportStorageMKLDNN */ *dispatch_mode = DispatchMode::kFComputeEx; (*out_attrs)[0] = kMKLDNNStorage; return true; @@ -110,7 +129,7 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs, NDArray out_nd = outputs[0]; mxnet::ndarray::ElementwiseSum(s, rsc, inputs, &out_nd); #if MXNET_USE_MKLDNN == 1 - } else if (common::ContainsStorage(inputs, kMKLDNNStorage)) { + } else if (SupportMKLDNN(inputs)) { /* was: any input in kMKLDNNStorage; now every input must pass SupportMKLDNN */ MKLDNNSum_Forward(attrs, op_ctx, inputs, req[0], outputs[0]); #endif } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {