
[mkldnn-v1.0] enable mkldnn concat (#16507)
* enable mkldnn concat

* trigger CI

* trigger CI
rongzha1 authored and pengzhao-intel committed Oct 17, 2019
1 parent 10db1b1 commit b3e02b1
Showing 1 changed file with 13 additions and 13 deletions.
src/operator/nn/concat.cc: 26 changes (13 additions & 13 deletions)
@@ -196,7 +196,7 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!dispatched && dev_mask == mshadow::cpu::kDevMask
&& common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
&& param.dim > 0) {
@@ -211,7 +211,7 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!MKLDNNEnvSet())
*dispatch_mode = DispatchMode::kFComputeFallback;
#endif
@@ -224,7 +224,7 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DispatchMode wanted_mode;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
if (dev_mask == mshadow::cpu::kDevMask
@@ -234,22 +234,22 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
else
#endif
wanted_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!MKLDNNEnvSet())
wanted_mode = DispatchMode::kFComputeFallback;
#endif
return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
for (auto &arr : arrs) {
if (arr.IsView()) return false;
if (arr.dtype() != mshadow::kFloat32) return false;
// Do not support zero-size tensors.
if (arr.shape().Size() == 0) return false;
int ndim = arr.shape().ndim();
-const int mkldnn_ndims = arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims;
+const int mkldnn_ndims = arr.GetMKLDNNData()->get_desc().data.ndims;
if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
}
return true;
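Besides the version-macro bump, the hunk above switches SupportMKLDNNConcat to the MKL-DNN v1.x descriptor API: a memory object now exposes its descriptor directly via get_desc(), replacing the v0.x chain get_primitive_desc().desc(). A minimal sketch of that query against a plain mkldnn::memory (a hypothetical standalone example, not MXNet's NDArray wrapper; HasSupportedNdims is an illustrative name):

#include <mkldnn.hpp>  // MKL-DNN / DNNL v1.x C++ API

// Hypothetical helper mirroring the ndims check in SupportMKLDNNConcat,
// written against a raw mkldnn::memory instead of mxnet::NDArray.
bool HasSupportedNdims(const mkldnn::memory &mem, int expected_ndim) {
  // In v1.x the descriptor comes straight from the memory object; the
  // v0.x chain get_primitive_desc().desc() no longer exists.
  const mkldnn::memory::desc md = mem.get_desc();
  return md.data.ndims == expected_ndim;
}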
@@ -267,7 +267,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
if (common::ContainsOnlyStorage(inputs, kCSRStorage) &&
outputs[0].storage_type() == kCSRStorage) {
ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
} else if (SupportMKLDNNConcat(inputs)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
@@ -280,7 +280,7 @@ }
}
}

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -302,7 +302,7 @@ struct ConcatGrad {
const std::vector<nnvm::NodeEntry>& ograds) const {
CHECK_EQ(ograds.size(), 1);
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
for (size_t i = 0; i < n->inputs.size(); i++) {
heads.push_back(n->inputs[i]);
}
@@ -381,7 +381,7 @@ Example::
[ 5., 5., 8., 8.]]
)code" ADD_FILELINE)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
@@ -398,14 +398,14 @@ NNVM_REGISTER_OP(_backward_Concat)
return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
#endif
@@ -416,7 +416,7 @@ NNVM_REGISTER_OP(_backward_Concat)
// unknown shape that can be inferred from output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
.add_alias("_npi_rnn_param_concat")
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
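The remaining hunks are mechanical: every preprocessor gate moves from MXNET_USE_MKLDNN == 1 to MXNET_USE_MKLDNN == 100, which, as the mkldnn-v1.0 branch name suggests, appears to distinguish the MKL-DNN v1.0 integration from the legacy v0.x one. A self-contained sketch of such a gate, under the assumption that the build system defines the macro accordingly (the actual build logic is not part of this diff):

// Hypothetical illustration of the version gate; in MXNet the value of
// MXNET_USE_MKLDNN is assumed to be set by the build system (100 when
// compiling against MKL-DNN v1.0, 1 for the legacy v0.x integration).
#ifndef MXNET_USE_MKLDNN
#define MXNET_USE_MKLDNN 100
#endif

#include <cstdio>

int main() {
#if MXNET_USE_MKLDNN == 100
  std::printf("MKL-DNN v1.0 code path\n");
#elif MXNET_USE_MKLDNN == 1
  std::printf("legacy MKL-DNN v0.x code path\n");
#else
  std::printf("MKL-DNN disabled\n");
#endif
  return 0;
}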
