From 922b6162e8b62af98fe211d1e24f1aa10716a0e6 Mon Sep 17 00:00:00 2001
From: Wuxun Zhang
Date: Fri, 11 Oct 2019 15:07:24 +0800
Subject: [PATCH] [mkldnn-v1.0] Add MKL-DNN reshape&flatten&expand_dims (#16258)

* Add mkldnn 1.0 support for reshape/flatten/expanddims ops

* improve log & modify definition location of args_map_

* fix comments

* rebase code

* trigger CI

* trigger CI

* trigger CI

* trigger CI
---
 src/operator/nn/mkldnn/mkldnn_base-inl.h     |   2 +-
 src/operator/nn/mkldnn/mkldnn_expand_dims.cc |  70 +++++++++++
 src/operator/nn/mkldnn/mkldnn_flatten-inl.h  |   2 +-
 src/operator/nn/mkldnn/mkldnn_flatten.cc     |   6 +-
 src/operator/nn/mkldnn/mkldnn_ops-inl.h      |  28 +++--
 src/operator/nn/mkldnn/mkldnn_reshape-inl.h  |  33 ++++-
 src/operator/nn/mkldnn/mkldnn_reshape.cc     | 124 +++++++------------
 src/operator/tensor/matrix_op-inl.h          |  14 +++
 src/operator/tensor/matrix_op.cc             |  54 ++++++--
 9 files changed, 225 insertions(+), 108 deletions(-)
 create mode 100644 src/operator/nn/mkldnn/mkldnn_expand_dims.cc

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index e4c4b98e7dff..c93cdb4b730f 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -189,7 +189,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
-bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
+bool SupportMKLDNNReshape(const NDArray &in_data, const NDArray &out_data);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_expand_dims.cc b/src/operator/nn/mkldnn/mkldnn_expand_dims.cc
new file mode 100644
index 000000000000..dcd85f1cf60c
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_expand_dims.cc
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_expand_dims.cc
+ * \brief Implement expand_dims operator via MKL-DNN reorder primitive
+ * \author Wuxun Zhang
+ */
+
+#if MXNET_USE_MKLDNN == 100
+
+#include "mkldnn_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNExpandDimsFwd : public MKLDNNReshapeFwd {
+ public:
+  explicit MKLDNNExpandDimsFwd(const OpReqType &req,
+                               const NDArray &input,
+                               const NDArray &output)
+      : MKLDNNReshapeFwd(req, input, output) {}
+};
+
+typedef ParamOpSign<ExpandDimParam> MKLDNNExpandDimsSignature;
+
+void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const NDArray &input,
+                             const OpReqType &req,
+                             const NDArray &output) {
+  const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
+  if (req == kNullOp) return;
+  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
+
+  auto fwd = GetCachedForward<MKLDNNExpandDimsFwd, ExpandDimParam,
+                              MKLDNNExpandDimsSignature>(param, req, input, output);
+
+  auto ws_size = fwd.GetWorkspaceSize();
+  void* ws_ptr = nullptr;
+  if (ws_size) {
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
+        .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
+    ws_ptr = reinterpret_cast<void*>(ws.dptr_);
+  }
+
+  fwd.Execute(input, output, req, ws_ptr);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
index ae890d8f3d91..89e52cc50988 100644
--- a/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_flatten-inl.h
@@ -25,7 +25,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include "mkldnn_reshape-inl.h"
 
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc
index 4090eb026cfc..4058399ab3fe 100644
--- a/src/operator/nn/mkldnn/mkldnn_flatten.cc
+++ b/src/operator/nn/mkldnn/mkldnn_flatten.cc
@@ -19,11 +19,11 @@
 
 /*!
  * \file mkldnn_flatten.cc
- * \brief Implement flatten operator by using mkldnn reorder primitive
+ * \brief Implement flatten operator via MKL-DNN reorder primitive
  * \author Wuxun Zhang
  */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include "mkldnn_flatten-inl.h"
 
@@ -70,7 +70,7 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
     ws_ptr = reinterpret_cast<void*>(ws.dptr_);
   }
 
-  fwd.Execute(input, output, ws_ptr);
+  fwd.Execute(input, output, req, ws_ptr);
 }
 
 }  // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 793aad7a60f5..ec97c9306076 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -63,18 +63,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<NDArray>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs);
-
-void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext &ctx,
-                          const NDArray &input,
-                          const OpReqType &req,
-                          const NDArray &output);
-
-void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
-                          const OpContext &ctx,
-                          const NDArray &input,
-                          const OpReqType &req,
-                          const NDArray &output);
 #endif
 
 #if MXNET_USE_MKLDNN == 100
@@ -142,6 +130,22 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const NDArray &data,
                             const OpReqType &req,
                             const NDArray &output);
+
+void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output);
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output);
+void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const NDArray &input,
+                             const OpReqType &req,
+                             const NDArray &output);
 #endif
 
 }  // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
index 63e367b4dc7f..aa0f11ca7afb 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -26,7 +26,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <vector>
 #include "mkldnn_base-inl.h"
 #include "../../tensor/matrix_op-inl.h"
@@ -36,7 +36,6 @@ namespace op {
 
 class MKLDNNReshapeFwd {
  protected:
-  std::shared_ptr<mkldnn::memory> data_;
   std::shared_ptr<mkldnn::memory> out_;
   std::shared_ptr<mkldnn::memory> temp_;
   std::vector<mkldnn::primitive> prims_;
@@ -47,15 +46,39 @@ class MKLDNNReshapeFwd {
            const NDArray &input,
            const NDArray &output);
   int GetWorkspaceSize();
-  void SetNewMem(const NDArray &input,
-                 const NDArray &output,
-                 void* workspace = nullptr);
   void Execute(const NDArray &input,
                const NDArray &output,
+               const OpReqType &req,
                void* workspace = nullptr);
 };
 
 typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
+
+template<typename MKLDNNOpFwdType, typename ParamType,
+         typename MKLDNNSignatureType>
+MKLDNNOpFwdType &GetCachedForward(const ParamType& param,
+                                  const OpReqType &req,
+                                  const NDArray &input,
+                                  const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNSignatureType,
+                                         MKLDNNOpFwdType, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSignatureType,
+                                            MKLDNNOpFwdType, OpHash> fwds;
+#endif
+  MKLDNNSignatureType key(param);
+  key.AddSign(req);
+  key.AddSign(input);
+  key.AddSign(output);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNOpFwdType fwd(req, input, output);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
 MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
                                     const OpReqType &req,
                                     const NDArray &input,
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index 063c85dae39a..d180125b16bb 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -23,7 +23,7 @@
  * \author Tao Lv
  */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <mkldnn.hpp>
 #include "mkldnn_reshape-inl.h"
 
@@ -31,13 +31,14 @@
 namespace mxnet {
 namespace op {
 
-bool SupportMKLDNNReshape(const ReshapeParam &param,
-                          const NDArray &data) {
-  auto data_ndim = data.shape().ndim();
+bool SupportMKLDNNReshape(const NDArray &in_data,
+                          const NDArray &out_data) {
+  auto in_ndim = in_data.shape().ndim();
+  auto out_ndim = out_data.shape().ndim();
 
-  if (data_ndim > 4 ||
-      data.dtype() != mshadow::kFloat32 ||
-      param.shape.ndim() > 4)
+  if (in_ndim > 4 ||
+      in_data.dtype() != mshadow::kFloat32 ||
+      out_ndim > 4)
     return false;
 
   return true;
@@ -48,21 +49,16 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
                                    const NDArray &output) {
   auto engine = CpuEngine::Get()->get_engine();
 
-  // data_
+  // source
   auto in_mem = input.GetMKLDNNData();
-  auto in_pd = in_mem->get_primitive_desc();
-  data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
+  auto in_md = in_mem->get_desc();
 
   // temp_
-  auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
-  auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
-  auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
-  auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
-  auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
-  temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+  auto temp_md = GetDesc(in_md, GetDefaultFormat(in_md));
+  temp_ = std::make_shared<mkldnn::memory>(temp_md, engine, nullptr);
 
   // destination
-  out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+  out_ = std::make_shared<mkldnn::memory>(temp_md, engine, nullptr);
 
   if (req == kWriteInplace) {
     // If the input has an MKL-DNN internal layout, we need to reorder it to a temporary
     // buffer with the default layout and copy from that temporary buffer back to the
     // output buffer, which shares its address with the input buffer.
     // If the input already has the default layout, nothing needs to be done.
     if (input.IsMKLDNNData()) {
-      prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
-      prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy back
+      prims_.push_back(mkldnn::reorder(*in_mem, *temp_));  // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy back
       needInvalidateInput = true;
     }
   } else if (req == kWriteTo) {
     if (input.IsMKLDNNData()) {
-      prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
-      prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy to the output buffer
+      prims_.push_back(mkldnn::reorder(*in_mem, *temp_));  // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy to the output buffer
       needInvalidateInput = false;
     } else {
-      prims_.push_back(mkldnn::reorder(*data_, *out_));    // copy directly from input to output
+      prims_.push_back(mkldnn::reorder(*in_mem, *out_));   // copy directly from input to output
       needInvalidateInput = false;
     }
   } else {
@@ -89,42 +85,36 @@
 int MKLDNNReshapeFwd::GetWorkspaceSize() {
-  return temp_ ? temp_->get_primitive_desc().get_size() : 0;
-}
-
-void MKLDNNReshapeFwd::SetNewMem(const NDArray &input,
-                                 const NDArray &output,
-                                 void* workspace) {
-  if (input.IsMKLDNNData()) {
-    this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
-  } else {
-    MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
-      this->data_->set_data_handle(input.data().dptr<DTYPE>());
-    })
-  }
-
-  if (output.IsMKLDNNData()) {
-    this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
-  } else {
-    MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
-      this->out_->set_data_handle(output.data().dptr<DTYPE>());
-    })
-  }
-
-  if (workspace) {
-    this->temp_->set_data_handle(workspace);
-  }
+  return temp_ ? temp_->get_desc().get_size() : 0;
 }
 
 void MKLDNNReshapeFwd::Execute(const NDArray &input,
                                const NDArray &output,
+                               const OpReqType &req,
                                void* workspace) {
-  // set memory handles
-  SetNewMem(input, output, workspace);
-  // register primitives
   auto stream = MKLDNNStream::Get();
-  for (auto &v : this->prims_) {
-    stream->RegisterPrim(v);
+  auto in_mem = input.GetMKLDNNData();
+  // register primitives and arguments
+  std::vector<mkldnn_args_map_t> args_map;
+  size_t prims_size = prims_.size();
+  if (prims_size == 1) {
+    args_map.push_back({{MKLDNN_ARG_FROM, *in_mem},
+                        {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+  } else if (prims_size == 2) {
+    if (workspace) {
+      temp_->set_data_handle(workspace);
+    }
+    args_map.push_back({{MKLDNN_ARG_FROM, *in_mem},
+                        {MKLDNN_ARG_TO, *temp_}});
+    args_map.push_back({{MKLDNN_ARG_FROM, *temp_},
+                        {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+  } else {
+    CHECK(prims_size == 0 && req != kWriteTo)
+        << "kWriteTo should never reach here.";
+  }
+
+  for (size_t i = 0; i < prims_size; i++) {
+    stream->RegisterPrimArgs(prims_[i], args_map[i]);
   }
   stream->Submit();
   // invalidate mkldnn memory in input
@@ -133,30 +123,6 @@ void MKLDNNReshapeFwd::Execute(const NDArray &input,
   }
 }
 
-MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
-                                    const OpReqType &req,
-                                    const NDArray &input,
-                                    const NDArray &output) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<MKLDNNReshapeSignature,
-                                         MKLDNNReshapeFwd, OpHash> fwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<MKLDNNReshapeSignature,
-                                            MKLDNNReshapeFwd, OpHash> fwds;
-#endif
-  MKLDNNReshapeSignature key(param);
-  key.AddSign(req);
-  key.AddSign(input);
-  key.AddSign(output);
-
-  auto it = fwds.find(key);
-  if (it == fwds.end()) {
-    MKLDNNReshapeFwd fwd(req, input, output);
-    it = AddToCache(&fwds, key, fwd);
-  }
-  return it->second;
-}
-
 void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
                           const NDArray &input,
@@ -166,7 +132,9 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
   if (req == kNullOp) return;
   CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
 
-  auto fwd = GetReshapeForward(param, req, input, output);
+  auto fwd = GetCachedForward<MKLDNNReshapeFwd, ReshapeParam,
+                              MKLDNNReshapeSignature>(param, req, input, output);
 
   auto ws_size = fwd.GetWorkspaceSize();
   void* ws_ptr = nullptr;
   if (ws_size) {
@@ -176,7 +144,7 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
     ws_ptr = reinterpret_cast<void*>(ws.dptr_);
   }
 
-  fwd.Execute(input, output, ws_ptr);
+  fwd.Execute(input, output, req, ws_ptr);
 }
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 5a2bd036c22b..3f1a5f83cec6 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -394,6 +394,10 @@ struct ExpandDimParam : public dmlc::Parameter<ExpandDimParam> {
                  "the input `NDArray`'s dimension is `ndim`, the range of "
                  "the inserted axis is `[-ndim, ndim]`");
   }
+
+  bool operator==(const ExpandDimParam &other) const {
+    return this->axis == other.axis;
+  }
 };
 
@@ -2936,6 +2940,16 @@
     return ret;
   }
 };
+
+template<>
+struct hash<mxnet::op::ExpandDimParam> {
+  size_t operator()(const mxnet::op::ExpandDimParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.axis);
+    return ret;
+  }
+};
+
 }  // namespace std
 
 #endif  // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index a4f0db0140e4..6bf1ec0c5d5c 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -25,9 +25,11 @@
 // this will be invoked by gcc and compile CPU version
 #include "./matrix_op-inl.h"
 #include "./elemwise_unary_op.h"
+#if MXNET_USE_MKLDNN == 100
 #include "../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../nn/mkldnn/mkldnn_base-inl.h"
 #include "../nn/mkldnn/mkldnn_slice-inl.h"
+#endif
 
 namespace mxnet {
 namespace op {
@@ -105,19 +107,18 @@ DMLC_REGISTER_PARAMETER(SqueezeParam);
 DMLC_REGISTER_PARAMETER(DepthToSpaceParam);
 DMLC_REGISTER_PARAMETER(SplitParam);
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
-  const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   // If the input is in MKL-DNN format and MKL-DNN supports the
   // data type and the shape, then convert it to the output
   // format and shape.
-  if (SupportMKLDNNReshape(param, inputs[0])) {
+  if (SupportMKLDNNReshape(inputs[0], outputs[0])) {
     MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   }
@@ -207,7 +208,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ReshapeComputeExCPU)
 .set_attr<FInferStorageType>("FInferStorageType", ReshapeStorageType)
@@ -233,7 +234,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
                       const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   auto data_ndim = inputs[0].shape().ndim();
   if (data_ndim <= 4 && inputs[0].dtype() == mshadow::kFloat32) {
     MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
@@ -248,7 +249,7 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
 #endif
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
                                       const int dev_mask,
                                       DispatchMode* dispatch_mode,
@@ -294,13 +295,13 @@ Example::
 .set_num_outputs(1)
 .set_attr<mxnet::FInferShape>("FInferShape", FlattenShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", FlattenStorageType)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_copy" })
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", FlattenEx)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -411,6 +412,33 @@ Examples::
 .add_arguments(TransposeParam::__FIELDS__());
 
 
+#if MXNET_USE_MKLDNN == 100
+static void ExpandDimEx(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<NDArray>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  auto data_ndim = inputs[0].shape().ndim();
+  if (data_ndim <= 3 && inputs[0].dtype() == mshadow::kFloat32) {
+    MKLDNNExpandDimsForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+    return;
+  }
+  FallBackCompute(UnaryOp::IdentityCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool ExpandDimStorageType(const nnvm::NodeAttrs& attrs,
+                                        const int dev_mask,
+                                        DispatchMode* dispatch_mode,
+                                        std::vector<int>* in_attrs,
+                                        std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(expand_dims)
 .add_alias("_npi_expand_dims")
 .describe(R"code(Inserts a new axis of size 1 into the array shape
 For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1)``
 will return a new array with shape ``(2,1,3,4)``.
 .set_attr_parser(ParamParser<ExpandDimParam>)
 .set_attr<mxnet::FInferShape>("FInferShape", ExpandDimShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+#if MXNET_USE_MKLDNN == 100
+.set_attr<FInferStorageType>("FInferStorageType", ExpandDimStorageType)
+#endif
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
   [](const NodeAttrs& attrs){
     return std::vector<std::pair<int, int> >{{0, 0}};
   })
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+#if MXNET_USE_MKLDNN == 100
+.set_attr<FComputeEx>("FComputeEx<cpu>", ExpandDimEx)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .add_arguments(ExpandDimParam::__FIELDS__());
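
Editor's note (not part of the patch): the point of the templated GetCachedForward helper
added to mkldnn_reshape-inl.h is that any reshape-like operator can now reuse the cached,
reorder-based reshape kernel with only a thin wrapper, as expand_dims does above. As a
hedged sketch of that extension point, a hypothetical squeeze front-end could be wired up
the same way. The MKLDNNSqueezeFwd / MKLDNNSqueezeForward names below are invented for
illustration, and the sketch assumes SqueezeParam were given an operator== and a std::hash
specialization analogous to the ones this patch adds for ExpandDimParam:

    // Hypothetical sketch only -- mirrors mkldnn_expand_dims.cc; not part of this patch.
    class MKLDNNSqueezeFwd : public MKLDNNReshapeFwd {
     public:
      explicit MKLDNNSqueezeFwd(const OpReqType &req,
                                const NDArray &input,
                                const NDArray &output)
          : MKLDNNReshapeFwd(req, input, output) {}  // reuse the reorder-based kernel
    };

    // Assumes SqueezeParam has operator== and a std::hash specialization.
    typedef ParamOpSign<SqueezeParam> MKLDNNSqueezeSignature;

    void MKLDNNSqueezeForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                              const NDArray &input, const OpReqType &req,
                              const NDArray &output) {
      const SqueezeParam &param = nnvm::get<SqueezeParam>(attrs.parsed);
      if (req == kNullOp) return;
      CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
      // One cache entry per (param, req, input, output) signature, per thread.
      auto fwd = GetCachedForward<MKLDNNSqueezeFwd, SqueezeParam,
                                  MKLDNNSqueezeSignature>(param, req, input, output);
      // Same workspace dance as reshape/flatten/expand_dims: a temp buffer is only
      // needed when the input carries an MKL-DNN internal layout (two reorders).
      auto ws_size = fwd.GetWorkspaceSize();
      void *ws_ptr = nullptr;
      if (ws_size) {
        mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
        mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
            .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
        ws_ptr = reinterpret_cast<void*>(ws.dptr_);
      }
      fwd.Execute(input, output, req, ws_ptr);
    }

The wrapper adds no kernel code of its own: the parent constructor decides between zero,
one, or two reorder primitives from (req, input layout), and Execute binds the argument
maps at run time, which is exactly what lets the three ops in this patch share one cache.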