This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Enable mkldnn cpp-test, copy op, concat op #16503

Merged: 2 commits, Oct 17, 2019
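The whole migration hinges on one build-time convention: the mkldnn-v1.0 branch gates v1.0-only code behind `MXNET_USE_MKLDNN == 100` instead of the old `== 1`, and echoes the condition on each closing `#endif`. A minimal sketch of the before/after guard pattern (illustrative only; `100` is the value this branch's build defines for MKL-DNN v1.0):

```cpp
// Old guard, MKL-DNN v0.x builds:
#if MXNET_USE_MKLDNN == 1
// ... v0.x-only code ...
#endif

// New guard, MKL-DNN v1.0 builds; the condition is repeated on the
// #endif so long conditional blocks stay readable:
#if MXNET_USE_MKLDNN == 100
// ... v1.0-only code ...
#endif  // MXNET_USE_MKLDNN == 100
```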
50 changes: 25 additions & 25 deletions src/operator/nn/concat.cc
@@ -196,25 +196,25 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
dispatched = storage_type_assign(&out_stype, kCSRStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!dispatched && dev_mask == mshadow::cpu::kDevMask
&& common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
&& param.dim > 0) {
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFComputeEx);
}
-#endif
+#endif // MXNET_USE_MKLDNN == 100
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!MKLDNNEnvSet())
*dispatch_mode = DispatchMode::kFComputeFallback;
-#endif
+#endif // MXNET_USE_MKLDNN == 100
return dispatched;
}

@@ -224,37 +224,37 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DispatchMode wanted_mode;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
if (dev_mask == mshadow::cpu::kDevMask
&& common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
&& param.dim > 0)
wanted_mode = DispatchMode::kFComputeEx;
else
-#endif
+#endif // MXNET_USE_MKLDNN == 100
wanted_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
if (!MKLDNNEnvSet())
wanted_mode = DispatchMode::kFComputeFallback;
-#endif
+#endif // MXNET_USE_MKLDNN == 100
return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
}
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
for (auto &arr : arrs) {
if (arr.IsView()) return false;
if (arr.dtype() != mshadow::kFloat32) return false;
// Do not support zero-size tensors.
if (arr.shape().Size() == 0) return false;
int ndim = arr.shape().ndim();
-const int mkldnn_ndims = arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims;
+const int mkldnn_ndims = arr.GetMKLDNNData()->get_desc().data.ndims;
if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
}
return true;
}
-#endif
+#endif // MXNET_USE_MKLDNN == 100
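The one non-mechanical edit in `SupportMKLDNNConcat` is the descriptor query: MKL-DNN v1.0 dropped `memory::primitive_desc`, so the dimensionality now comes straight off the `memory::desc`. A hedged sketch of the old and new calls on a plain `mkldnn::memory` object (not MXNet's wrapper):

```cpp
#include <mkldnn.hpp>  // MKL-DNN v1.0 C++ API

// v0.x: the descriptor was reached through a primitive_desc:
//   int ndims = mem.get_primitive_desc().desc().data.ndims;
// v1.0: memory::desc is queried directly; its .data member is the
// C-level mkldnn_memory_desc_t carrying ndims, dims and data_type.
int QueryNdims(const mkldnn::memory &mem) {
  return mem.get_desc().data.ndims;
}
```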
static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& op_ctx,
const std::vector<NDArray>& inputs,
@@ -267,20 +267,20 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
if (common::ContainsOnlyStorage(inputs, kCSRStorage) &&
outputs[0].storage_type() == kCSRStorage) {
ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
} else if (SupportMKLDNNConcat(inputs)) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
} else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
FallBackCompute(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
-#endif
+#endif // MXNET_USE_MKLDNN == 100
} else {
LogUnimplementedOp(attrs, op_ctx, inputs, req, outputs);
}
}

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -294,19 +294,19 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
}
FallBackCompute(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
-#endif
+#endif // MXNET_USE_MKLDNN == 100

struct ConcatGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
CHECK_EQ(ograds.size(), 1);
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
for (size_t i = 0; i < n->inputs.size(); i++) {
heads.push_back(n->inputs[i]);
}
-#endif
+#endif // MXNET_USE_MKLDNN == 100
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};
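Under the MKL-DNN guard, `ConcatGrad` appends the forward inputs to the gradient node's heads: concat backward is a split of the output gradient, and the split boundaries come from the input shapes, which the backward kernel reads off those extra inputs. A minimal, API-free sketch of that computation, with shapes collapsed to flat element counts for brevity:

```cpp
#include <cstddef>
#include <vector>

// Slice a flat output gradient into per-input gradients. Real kernels
// split along param.dim; element counts stand in for shapes here.
std::vector<std::vector<float>> ConcatBackwardSketch(
    const std::vector<float>& ograd,
    const std::vector<std::size_t>& input_sizes) {
  std::vector<std::vector<float>> igrads;
  std::size_t offset = 0;
  for (std::size_t sz : input_sizes) {
    igrads.emplace_back(ograd.begin() + offset, ograd.begin() + offset + sz);
    offset += sz;
  }
  return igrads;
}
```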
@@ -381,12 +381,12 @@ Example::
[ 5., 5., 8., 8.]]

)code" ADD_FILELINE)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<bool>("TIsMKLDNN", true)
-#endif
+#endif // MXNET_USE_MKLDNN == 100
CONCAT_FORWARD_ATTRS
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
@@ -398,29 +398,29 @@ NNVM_REGISTER_OP(_backward_Concat)
return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-#endif
+#endif // MXNET_USE_MKLDNN == 100
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
-#endif
+#endif // MXNET_USE_MKLDNN == 100
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

// _rnn_param_concat is a custom concat op with specialized infer_shape,
// which handles the case where the first one or two inputs may have
// unknown shape that can be inferred from output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
.add_alias("_npi_rnn_param_concat")
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-#endif
+#endif // MXNET_USE_MKLDNN == 100
CONCAT_FORWARD_ATTRS
.set_attr<mxnet::FInferShape>("FInferShape", RNNParamConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
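The comment above `_rnn_param_concat` summarizes its only specialization: when one leading input has an unknown shape, its extent along the concat axis is the output extent minus the sum of the known extents. A toy sketch of that backward inference (names are my own, not MXNet's):

```cpp
#include <vector>

// Infer a single unknown concat-axis extent (marked -1) from the output
// extent. Returns false if more than one extent is unknown.
bool InferConcatAxisExtent(std::vector<int>* in_extents, int out_extent) {
  int known_sum = 0;
  int unknown = -1;
  for (int i = 0; i < static_cast<int>(in_extents->size()); ++i) {
    int e = (*in_extents)[i];
    if (e < 0) {
      if (unknown >= 0) return false;  // two unknowns: underdetermined
      unknown = i;
    } else {
      known_sum += e;
    }
  }
  if (unknown >= 0) (*in_extents)[unknown] = out_extent - known_sum;
  return true;
}
```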
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -101,4 +101,4 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

} // namespace op
} // namespace mxnet
-#endif
+#endif // MXNET_USE_MKLDNN == 100
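For context, `MKLDNNConcatForward` ultimately drives the library's concat primitive. A self-contained sketch of that v1.0 API (illustrative, not MXNet code): build a `concat::primitive_desc` from the source descriptors and execute it with the `MKLDNN_ARG_MULTIPLE_SRC` argument map.

```cpp
#include <mkldnn.hpp>
#include <unordered_map>

// Concatenate two f32 matrices along axis 0 with the MKL-DNN v1.0 API.
void ConcatDemo() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream s(eng);

  mkldnn::memory::desc a_md({2, 3}, mkldnn::memory::data_type::f32,
                            mkldnn::memory::format_tag::nc);
  mkldnn::memory::desc b_md({4, 3}, mkldnn::memory::data_type::f32,
                            mkldnn::memory::format_tag::nc);
  mkldnn::memory a_mem(a_md, eng), b_mem(b_md, eng);

  // Let the library derive the destination descriptor.
  mkldnn::concat::primitive_desc pd(0 /* axis */, {a_md, b_md}, eng);
  mkldnn::memory dst_mem(pd.dst_desc(), eng);

  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_MULTIPLE_SRC + 0, a_mem},
      {MKLDNN_ARG_MULTIPLE_SRC + 1, b_mem},
      {MKLDNN_ARG_DST, dst_mem}};
  mkldnn::concat(pd).execute(s, args);
  s.wait();
}
```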
16 changes: 8 additions & 8 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -204,7 +204,7 @@ static void CopyEx(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
const auto in_stype = inputs[0].storage_type();
const auto out_stype = outputs[0].storage_type();
if (inputs[0].IsMKLDNNData()) {
@@ -217,7 +217,7 @@ static void CopyEx(const nnvm::NodeAttrs& attrs,
FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
-#endif
+#endif // MXNET_USE_MKLDNN == 100
UnaryOp::IdentityComputeEx<cpu>(attrs, ctx, inputs, req, outputs);
}
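`CopyEx` bails out to `FallBackCompute` when the input holds an MKL-DNN-internal layout that the plain identity kernel cannot read; the actual normalization is a reorder into a default-layout buffer. A minimal sketch of that conversion, assuming `src` carries some blocked layout and `dst` was created with a plain `format_tag`:

```cpp
#include <mkldnn.hpp>

// Reorder an MKL-DNN-layout buffer into a plain (default-layout) one,
// which is what falling back to the FCompute path requires.
void NormalizeLayout(mkldnn::memory& src, mkldnn::memory& dst,
                     const mkldnn::engine& eng) {
  mkldnn::stream s(eng);
  mkldnn::reorder(src, dst).execute(s, src, dst);
  s.wait();
}
```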

@@ -230,15 +230,15 @@ static inline bool CopyStorageType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1);
bool ret = ElemwiseStorageType<1, 1, false, true, true>(attrs, dev_mask, dispatch_mode,
in_attrs, out_attrs);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
// We have to make sure all inputs use the default layout. Otherwise, we
// might need to fall back.
if (dev_mask == mshadow::cpu::kDevMask
&& in_attrs->at(0) == kDefaultStorage
&& out_attrs->at(0) == kDefaultStorage) {
*dispatch_mode = DispatchMode::kFComputeEx;
}
-#endif
+#endif // MXNET_USE_MKLDNN == 100
return ret;
}

@@ -248,12 +248,12 @@ MXNET_OPERATOR_REGISTER_UNARY(_copy)
.set_attr<FInferStorageType>("FInferStorageType", CopyStorageType)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", CopyEx)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<bool>("TIsMKLDNN", true)
-#endif
+#endif // MXNET_USE_MKLDNN == 100
.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
@@ -271,11 +271,11 @@ NNVM_REGISTER_OP(_backward_copy)
.set_attr<FInferStorageType>("FInferStorageType", CopyStorageType)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", CopyEx)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-#endif
+#endif // MXNET_USE_MKLDNN == 100
.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
[](const NodeAttrs& attrs){
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/include/test_core_op.h
@@ -79,7 +79,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
keys.emplace_back(i_iter->first.c_str());
values.emplace_back(i_iter->second.c_str());
}
-return imperative::ParseAttrs(op, op->num_inputs, count, &keys[0], &values[0]);
+return imperative::ParseAttrs(op, op->num_inputs, count, keys.data(), values.data());
}

/*!
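The lone change in the C++ test harness swaps `&keys[0]` for `keys.data()`. The two are interchangeable on a non-empty vector, but `&keys[0]` on an empty vector indexes element 0 of nothing, which is undefined behavior, whereas `data()` is well-defined for any size (its pointer simply must not be dereferenced when empty). A tiny illustration:

```cpp
#include <vector>

void TakesPointer(const char* const* /*p*/, int /*n*/) { /* reads p[0..n-1] */ }

void Caller(const std::vector<const char*>& keys) {
  // Undefined behavior when keys is empty:
  //   TakesPointer(&keys[0], 0);
  // Well-defined for any size, including zero:
  TakesPointer(keys.data(), static_cast<int>(keys.size()));
}
```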