Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0]enable mkldnn concat #16507

Merged
merged 3 commits into from
Oct 17, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/operator/nn/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
if (!dispatched && dev_mask == mshadow::cpu::kDevMask
&& common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
&& param.dim > 0) {
Expand All @@ -211,7 +211,7 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
if (!MKLDNNEnvSet())
*dispatch_mode = DispatchMode::kFComputeFallback;
#endif
Expand All @@ -224,7 +224,7 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DispatchMode wanted_mode;
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
if (dev_mask == mshadow::cpu::kDevMask
Expand All @@ -234,22 +234,22 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
else
#endif
wanted_mode = DispatchMode::kFCompute;
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
if (!MKLDNNEnvSet())
wanted_mode = DispatchMode::kFComputeFallback;
#endif
return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
}
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
for (auto &arr : arrs) {
if (arr.IsView()) return false;
if (arr.dtype() != mshadow::kFloat32) return false;
// DO not support zero-size tensors.
if (arr.shape().Size() == 0) return false;
int ndim = arr.shape().ndim();
const int mkldnn_ndims = arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims;
const int mkldnn_ndims = arr.GetMKLDNNData()->get_desc().data.ndims;
if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
}
return true;
Expand All @@ -267,7 +267,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
if (common::ContainsOnlyStorage(inputs, kCSRStorage) &&
outputs[0].storage_type() == kCSRStorage) {
ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
} else if (SupportMKLDNNConcat(inputs)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
Expand All @@ -280,7 +280,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
}
}

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
Expand All @@ -302,7 +302,7 @@ struct ConcatGrad {
const std::vector<nnvm::NodeEntry>& ograds) const {
CHECK_EQ(ograds.size(), 1);
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
for (size_t i = 0; i < n->inputs.size(); i++) {
heads.push_back(n->inputs[i]);
}
Expand Down Expand Up @@ -381,7 +381,7 @@ Example::
[ 5., 5., 8., 8.]]
)code" ADD_FILELINE)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
Expand All @@ -398,14 +398,14 @@ NNVM_REGISTER_OP(_backward_Concat)
return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
#endif
Expand All @@ -416,7 +416,7 @@ NNVM_REGISTER_OP(_backward_Concat)
// unknown shape that can be inferred from output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
.add_alias("_npi_rnn_param_concat")
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
Expand Down