Skip to content

Commit

Permalink
Use MKLDNN sum in more cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Dec 15, 2017
1 parent d6d74f4 commit 7a21ebc
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
7 changes: 6 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,18 @@ static inline bool SupportMKLDNNArray(int dtype, const TShape &shape) {
return support;
}

static inline bool SupportStorageMKLDNN(int stype) {
return stype == kMKLDNNStorage || stype == kDefaultStorage;
}

static inline bool SupportMKLDNN(int dtype, const TShape &shape) {
int ndim = shape.ndim();
return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNN(const NDArray &input) {
return SupportMKLDNN(input.dtype(), input.shape());
return SupportMKLDNN(input.dtype(), input.shape())
&& SupportStorageMKLDNN(input.storage_type());
}

static inline bool SupportMKLDNNConv(const NDArray &input) {
Expand Down
6 changes: 3 additions & 3 deletions src/operator/tensor/elemwise_binary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "./elemwise_unary_op.h"
#include "./elemwise_binary_op-inl.h"
#include "../nn/mkldnn/mkldnn_ops-inl.h"
#include "../nn/mkldnn/mkldnn_base-inl.h"

namespace mxnet {
namespace op {
Expand All @@ -37,8 +38,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
#if MXNET_USE_MKLDNN == 1
if (inputs[0].storage_type() == kMKLDNNStorage
|| inputs[1].storage_type() == kMKLDNNStorage) {
if (SupportMKLDNN(inputs[0]) && SupportMKLDNN(inputs[1])) {
MKLDNNSum_Forward(attrs, ctx, inputs, req[0], outputs[0]);
return;
} else if (inputs[0].storage_type() == kDefaultStorage
Expand Down Expand Up @@ -68,7 +68,7 @@ static inline bool ElemwiseAddStorageType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
#if MXNET_USE_MKLDNN == 1
if ((in_attrs->at(0) == kMKLDNNStorage || in_attrs->at(1) == kMKLDNNStorage)
if ((SupportStorageMKLDNN(in_attrs->at(0)) || SupportStorageMKLDNN(in_attrs->at(1)))
&& dev_mask == mshadow::cpu::kDevMask) {
out_attrs->at(0) = kMKLDNNStorage;
*dispatch_mode = DispatchMode::kFComputeEx;
Expand Down
25 changes: 22 additions & 3 deletions src/operator/tensor/elemwise_sum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "./elemwise_sum.h"
#include "../../ndarray/ndarray_function.h"
#include "../nn/mkldnn/mkldnn_ops-inl.h"
#include "../nn/mkldnn/mkldnn_base-inl.h"
#include "../../common/utils.h"

namespace mxnet {
Expand Down Expand Up @@ -74,6 +75,25 @@ bool ElementWiseSumType(const nnvm::NodeAttrs& attrs,
attrs, in_attrs, out_attrs, -1);
}

#if MXNET_USE_MKLDNN == 1
static inline bool SupportMKLDNN(const std::vector<NDArray>& inputs) {
for (auto &i : inputs) {
if (!SupportMKLDNN(i))
return false;
}
return true;
}

static inline bool SupportStorageMKLDNN(const std::vector<int> &inputs) {
for (int i : inputs) {
if (!mxnet::SupportStorageMKLDNN(i))
return false;
}
return true;
}

#endif

bool ElementWiseSumForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
Expand All @@ -82,8 +102,7 @@ bool ElementWiseSumForwardInferStorageType(const nnvm::NodeAttrs& attrs,
CHECK(!in_attrs->empty());
CHECK_EQ(out_attrs->size(), 1U);
#if MXNET_USE_MKLDNN == 1
if (dev_mask == mshadow::cpu::kDevMask
&& common::ContainsStorage(*in_attrs, kMKLDNNStorage)) {
if (dev_mask == mshadow::cpu::kDevMask && SupportStorageMKLDNN(*in_attrs)) {
*dispatch_mode = DispatchMode::kFComputeEx;
(*out_attrs)[0] = kMKLDNNStorage;
return true;
Expand All @@ -110,7 +129,7 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs,
NDArray out_nd = outputs[0];
mxnet::ndarray::ElementwiseSum<cpu>(s, rsc, inputs, &out_nd);
#if MXNET_USE_MKLDNN == 1
} else if (common::ContainsStorage(inputs, kMKLDNNStorage)) {
} else if (SupportMKLDNN(inputs)) {
MKLDNNSum_Forward(attrs, op_ctx, inputs, req[0], outputs[0]);
#endif
} else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
Expand Down

0 comments on commit 7a21ebc

Please sign in to comment.