diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index 7f8638630145..2c4127b9a088 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -332,166 +332,50 @@ class LeakyReLUOp : public Operator { }; // class LeakyReLUOp template -Operator* CreateOp(LeakyReLUParam type, int dtype); +void LeakyReLUCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const LeakyReLUParam ¶m = nnvm::get(attrs.parsed); + const std::vector no_use_but_adapt_origin_api; + size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1; + CHECK_EQ(inputs.size(), expected); -#if DMLC_USE_CXX11 -class LeakyReLUProp : public OperatorProperty { - public: - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - bool InferShape(mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape, - mxnet::ShapeVector *aux_shape) const override { - using namespace mshadow; - if (param_.act_type == leakyrelu::kPReLU) { - CHECK_EQ(in_shape->size(), 2U) << "Input:[data, gamma]"; - } else { - CHECK_EQ(in_shape->size(), 1U) << "Input:[data]"; - } - const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData); - if (!mxnet::ndim_is_known(dshape)) return false; - if (param_.act_type == leakyrelu::kPReLU) { - const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma); - if (!mxnet::ndim_is_known(gshape)) { - in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1])); - } - if (dshape == gshape) { - SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape); - } - } - out_shape->clear(); - out_shape->push_back(dshape); - if (param_.act_type == leakyrelu::kRReLU) { - out_shape->push_back(dshape); - } - return true; - } - - bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { - int dtype = -1; - for (const int& type : *in_type) { - type_assign(&dtype, type); - } - for (const int& type : *out_type) { - type_assign(&dtype, type); - } - - for (size_t i = 0; i < in_type->size(); ++i) { - TYPE_ASSIGN_CHECK(*in_type, i, dtype); - } - for (size_t i = 0; i < out_type->size(); ++i) { - TYPE_ASSIGN_CHECK(*out_type, i, dtype); - } - return dtype != -1; - } - - OperatorProperty* Copy() const override { - auto ptr = new LeakyReLUProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { - return "LeakyReLU"; - } - - // decalre dependency and inplace optimization options - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - if (param_.act_type == leakyrelu::kPReLU) { - return {out_grad[leakyrelu::kOut], - out_data[leakyrelu::kOut], - in_data[leakyrelu::kData], - in_data[leakyrelu::kGamma]}; - } else if (param_.act_type == leakyrelu::kRReLU) { - return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kMask], out_data[leakyrelu::kOut]}; - } else { - return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kData]}; - } - } + MSHADOW_REAL_TYPE_SWITCH(inputs[leakyrelu::kData].type_flag_, DType, { + LeakyReLUOp op(param); + op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api); + }); +} - std::vector > BackwardInplaceOption( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &in_grad) const override { - return {{out_grad[leakyrelu::kOut], in_grad[leakyrelu::kData]}}; - } - - std::vector > ForwardInplaceOption( - const std::vector &in_data, - const std::vector &out_data) const override { - if (param_.act_type == leakyrelu::kPReLU) { - return {}; - } else { - return {{in_data[leakyrelu::kData], out_data[leakyrelu::kOut]}}; - } - } - - std::vector ListArguments() const override { - if (param_.act_type == leakyrelu::kPReLU) { - return {"data", "gamma"}; - } else { - return {"data"}; - } - } - - std::vector ListOutputs() const override { - if (param_.act_type == leakyrelu::kRReLU) { - return {"output", "mask"}; - } else { - return {"output"}; - } - } - - int NumOutputs() const override { - if (param_.act_type == leakyrelu::kRReLU) { - return 2; - } else { - return 1; - } - } - - int NumVisibleOutputs() const override { - return 1; - } - - std::vector ForwardResource( - const mxnet::ShapeVector &in_shape) const override { - if (param_.act_type == leakyrelu::kRReLU) { - return {ResourceRequest::kRandom}; - } else { - return std::vector(); - } - } +template +void LeakyReLUGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + const std::vector no_use_but_adapt_origin_api; + // inputs: out_grad, input_data, input_gamma, output, output_mask + size_t expected_in = param.act_type == leakyrelu::kPReLU ? 2 : 1; + size_t expected_out = param.act_type == leakyrelu::kRReLU ? 2 : 1; - std::vector BackwardResource( - const mxnet::ShapeVector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } + CHECK_GE(inputs.size(), 1 + expected_in + expected_out); + std::vector out_grad{inputs[0]}; + std::vector in_data(inputs.begin() + 1, + inputs.begin() + 1 + expected_in); + std::vector out_data(inputs.begin() + 1 + expected_in, + inputs.begin() + 1 + expected_in + expected_out); - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not Implemented."; - return NULL; - } + CHECK_EQ(req.size(), outputs.size()); + int dtype = inputs[0].type_flag_; + const std::vector &in_grad = outputs; - Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const override; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + LeakyReLUOp op(param); + op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api); + }); +} - private: - LeakyReLUParam param_; -}; -#endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc index c25833b799d0..4d1c5ca10a30 100644 --- a/src/operator/leaky_relu.cc +++ b/src/operator/leaky_relu.cc @@ -25,27 +25,123 @@ */ #include "./leaky_relu-inl.h" +#if MXNET_USE_MKLDNN == 1 +#include "./nn/mkldnn/mkldnn_base-inl.h" +#include "./nn/mkldnn/mkldnn_ops-inl.h" +#endif // MXNET_USE_MKLDNN == 1 #include namespace mxnet { namespace op { -template<> -Operator *CreateOp(LeakyReLUParam param, int dtype) { - Operator* op = nullptr; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new LeakyReLUOp(param); - }); - return op; + +DMLC_REGISTER_PARAMETER(LeakyReLUParam); + +static bool LeakyReLUType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + int dtype = -1; + for (const int& type : *in_type) { + type_assign(&dtype, type); + } + for (const int& type : *out_type) { + type_assign(&dtype, type); + } + for (size_t i = 0; i < in_type->size(); ++i) { + TYPE_ASSIGN_CHECK(*in_type, i, dtype); + } + for (size_t i = 0; i < out_type->size(); ++i) { + TYPE_ASSIGN_CHECK(*out_type, i, dtype); + } + return dtype != -1; } -Operator *LeakyReLUProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const { - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); +static bool LeakyReLUShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using namespace mshadow; + const LeakyReLUParam ¶m_ = nnvm::get(attrs.parsed); + if (param_.act_type == leakyrelu::kPReLU) { + CHECK_EQ(in_shape->size(), 2U) << "Input:[data, gamma]"; + } else { + CHECK_EQ(in_shape->size(), 1U) << "Input:[data]"; + } + const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData); + if (!mxnet::ndim_is_known(dshape)) return false; + if (param_.act_type == leakyrelu::kPReLU) { + const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma); + if (!mxnet::ndim_is_known(gshape)) { + in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1])); + } + if (dshape == gshape) { + SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape); + } + } + out_shape->clear(); + out_shape->push_back(dshape); + if (param_.act_type == leakyrelu::kRReLU) { + out_shape->push_back(dshape); + } + return true; } -DMLC_REGISTER_PARAMETER(LeakyReLUParam); +#if MXNET_USE_MKLDNN == 1 +static void LeakyReLUComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1; + CHECK_EQ(inputs.size(), expected); + if (SupportMKLDNNLeakyRelu(param, inputs[0])) { + MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); + MKLDNNLeakyReluForward(attrs, ctx, inputs[0], req[0], outputs[0]); + MKLDNN_OPCHECK_RUN(LeakyReLUCompute, attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(LeakyReLUCompute, attrs, ctx, inputs, req, outputs); +} + +void LeakyReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + if (SupportMKLDNNLeakyRelu(param, inputs[0])) { + std::vector in_data{inputs[0], inputs[1]}; + MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); + MKLDNNLeakyReluBackward(attrs, ctx, in_data, req[0], outputs[0]); + MKLDNN_OPCHECK_RUN(LeakyReLUGradCompute, attrs, ctx, inputs, req, outputs); + return; + } + FallBackCompute(LeakyReLUGradCompute, attrs, ctx, inputs, req, outputs); +} + +inline static bool LeakyReLUStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1; + CHECK_EQ(in_attrs->size(), expected); + return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNLeakyRelu(param), + dispatch_mode, in_attrs, out_attrs); +} -MXNET_REGISTER_OP_PROPERTY(LeakyReLU, LeakyReLUProp) +inline static bool BackwardLeakyReLUStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNLeakyRelu(param), + dispatch_mode, in_attrs, out_attrs); +} +#endif // MXNET_USE_MKLDNN == 1 + +NNVM_REGISTER_OP(LeakyReLU) .describe(R"code(Applies Leaky rectified linear unit activation element-wise to the input. Leaky ReLUs attempt to fix the "dying ReLU" problem by allowing a small `slope` @@ -63,15 +159,45 @@ The following modified ReLU Activation functions are supported: *(lower_bound+upper_bound)/2* for inference. )code" ADD_FILELINE) -.add_argument("data", "NDArray-or-Symbol", "Input data to activation function.") -.add_argument("gamma", "NDArray-or-Symbol", - "Slope parameter for PReLU. Only required " - "when act_type is 'prelu'. It should be either a vector of size 1, " - "or the same size as the second dimension of data.") -.add_arguments(LeakyReLUParam::__FIELDS__()); - -NNVM_REGISTER_OP(LeakyReLU) .add_alias("_npx_leaky_relu") +.set_num_inputs([](const NodeAttrs& attrs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + return param.act_type == leakyrelu::kPReLU ? 2 : 1; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + return param.act_type == leakyrelu::kRReLU ? 2 : 1; +}) +.set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FInferStorageType", LeakyReLUStorageType) +#endif +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + return param.act_type == leakyrelu::kPReLU ? std::vector{"data", "gamma"} + : std::vector{"data"}; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + return param.act_type == leakyrelu::kRReLU ? std::vector{"output", "mask"} + : std::vector{"output"}; +}) +.set_attr("FInferShape", LeakyReLUShape) +.set_attr("FInferType", LeakyReLUType) +.set_attr("FCompute", LeakyReLUCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", LeakyReLUComputeExCPU) +#endif +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_LeakyReLU"}) +.set_attr("FInplaceOption", [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; +}) +.add_argument("data", "NDArray-or-Symbol", "Input data to activation function.") +.add_argument("gamma", "NDArray-or-Symbol", "Input data to activation function.") +.add_arguments(LeakyReLUParam::__FIELDS__()) .set_attr("FSetInputVarAttrOnCompose", [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) { if (index == 1 && var->attrs.dict.find("__init__") == var->attrs.dict.end()) { @@ -79,5 +205,27 @@ NNVM_REGISTER_OP(LeakyReLU) } }); +NNVM_REGISTER_OP(_backward_LeakyReLU) +.set_num_outputs([](const NodeAttrs& attrs) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + return param.act_type == leakyrelu::kPReLU ? 2 : 1; +}) +.set_attr("TIsBackward", true) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FInferStorageType", BackwardLeakyReLUStorageType) +#endif +.set_attr("FInplaceOption", [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; +}) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr_parser(ParamParser) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FComputeEx", LeakyReLUGradComputeExCPU) +#endif +.set_attr("FCompute", LeakyReLUGradCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/leaky_relu.cu b/src/operator/leaky_relu.cu index a2e0e959a15b..fa8e95b4b952 100644 --- a/src/operator/leaky_relu.cu +++ b/src/operator/leaky_relu.cu @@ -28,14 +28,12 @@ namespace mxnet { namespace op { -template<> -Operator *CreateOp(LeakyReLUParam param, int dtype) { - Operator* op = nullptr; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new LeakyReLUOp(param); - }); - return op; -} + +NNVM_REGISTER_OP(LeakyReLU) +.set_attr("FCompute", LeakyReLUCompute); + +NNVM_REGISTER_OP(_backward_LeakyReLU) +.set_attr("FCompute", LeakyReLUGradCompute); } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_act-inl.h b/src/operator/nn/mkldnn/mkldnn_act-inl.h index 6bf30e3f3bbe..9c21b7f70f52 100644 --- a/src/operator/nn/mkldnn/mkldnn_act-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_act-inl.h @@ -32,22 +32,34 @@ #include #include #include "../activation-inl.h" +#include "../../leaky_relu-inl.h" #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" namespace mxnet { namespace op { +struct MKLDNNActParam { + mkldnn::algorithm alg; + float slope = 0.f; + + bool operator==(const MKLDNNActParam& other) const { + return this->alg == other.alg && + this->slope == other.slope; + } +}; + mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param); +mkldnn::algorithm GetMKLDNNActAlgo(const LeakyReLUParam& param); mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( - const ActivationParam& param, bool is_train, + const MKLDNNActParam& param, bool is_train, const mkldnn::memory &input_mem, int dtype); class MKLDNNActForward { public: const mkldnn::eltwise_forward::primitive_desc fwd_pd; - MKLDNNActForward(const ActivationParam& param, bool is_train, + MKLDNNActForward(const MKLDNNActParam& param, bool is_train, const NDArray &data, const mkldnn::memory &mem): fwd_pd( GetActFwdDescImpl(param, is_train, mem, data.dtype())) {} void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output); @@ -59,16 +71,31 @@ class MKLDNNActForward { std::shared_ptr out_; }; -typedef ParamOpSign MKLDNNActSignature; -MKLDNNActForward &GetActForward(const ActivationParam& param, +typedef ParamOpSign MKLDNNActSignature; +MKLDNNActForward &GetActForward(const MKLDNNActParam& param, const OpContext &ctx, const NDArray &in_data, const mkldnn::memory &in_mem); void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data); +void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); } // namespace op } // namespace mxnet +namespace std { +template<> +struct hash { + size_t operator()(const mxnet::op::MKLDNNActParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, static_cast(val.alg)); + ret = dmlc::HashCombine(ret, val.slope); + return ret; + } +}; +} // namespace std + #endif // MXNET_USE_MKLDNN == 1 #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc index e4c829645e13..f221ddf5e345 100644 --- a/src/operator/nn/mkldnn/mkldnn_act.cc +++ b/src/operator/nn/mkldnn/mkldnn_act.cc @@ -57,6 +57,20 @@ bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) { return SupportMKLDNNAct(param); } +bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param) { + return param.act_type == leakyrelu::kLeakyReLU + || param.act_type == leakyrelu::kELU; +} + +bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param, const NDArray &input) { + // MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout + if ((input.shape().ndim() < 1) || + (input.shape().ndim() > 4) || + (input.dtype() != mshadow::kFloat32)) + return false; + return SupportMKLDNNLeakyRelu(param); +} + bool SupportQuantizedMKLDNNAct(const ActivationParam ¶m) { // TODO(zhennan): Add more activation type when mkldnn supports. // Remove this when it's identity to SupportMKLDNNAct. @@ -79,18 +93,30 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) { } } +mkldnn::algorithm GetMKLDNNActAlgo(const LeakyReLUParam& param) { + switch (param.act_type) { + case leakyrelu::kLeakyReLU: + return mkldnn::algorithm::eltwise_relu; + case leakyrelu::kELU: + return mkldnn::algorithm::eltwise_elu; + default: + LOG(FATAL) << "unknown activation type for LeakyReLU: " << param.act_type; + return mkldnn::algorithm::eltwise_relu; + } +} + mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl( - const ActivationParam& param, bool is_train, + const MKLDNNActParam& param, bool is_train, const mkldnn::memory &input_mem, int dtype) { mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); mkldnn::memory::desc data_md = data_mpd.desc(); auto cpu_engine = data_mpd.get_engine(); - auto alg = GetMKLDNNActAlgo(param); + auto alg = param.alg; auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; - auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f); + auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, param.slope); return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine); } @@ -119,7 +145,7 @@ const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const { return *fwd_; } -MKLDNNActForward &GetActForward(const ActivationParam& param, +MKLDNNActForward &GetActForward(const MKLDNNActParam& param, const OpContext &ctx, const NDArray &in_data, const mkldnn::memory &in_mem) { #if DMLC_CXX11_THREAD_LOCAL @@ -129,7 +155,8 @@ MKLDNNActForward &GetActForward(const ActivationParam& param, #endif MKLDNNActSignature key(param); key.AddSign(ctx.is_train); - key.AddSign(param.act_type); + key.AddSign(param.alg); + key.AddSign(param.slope); key.AddSign(in_data); auto it = fwds.find(key); @@ -144,6 +171,8 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &in_data, const OpReqType &req, const NDArray &out_data) { const ActivationParam& param = nnvm::get(attrs.parsed); + MKLDNNActParam param_; + param_.alg = GetMKLDNNActAlgo(param); NDArray in_buffer = in_data; MKLDNNStream *stream = MKLDNNStream::Get(); @@ -152,7 +181,30 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, in_buffer = in_data.Reorder2Default(); auto input_mem = in_buffer.GetMKLDNNData(); - MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem); + MKLDNNActForward &fwd = GetActForward(param_, ctx, in_buffer, *input_mem); + auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer); + fwd.SetNewMem(*input_mem, *out_mem_t.second); + stream->RegisterPrim(fwd.GetFwd()); + CommitOutput(out_data, out_mem_t); + stream->Submit(); +} + +void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data) { + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + MKLDNNActParam param_; + param_.alg = GetMKLDNNActAlgo(param); + param_.slope = param.slope; + + NDArray in_buffer = in_data; + MKLDNNStream *stream = MKLDNNStream::Get(); + + if (in_data.IsView() && in_data.IsMKLDNNData()) + in_buffer = in_data.Reorder2Default(); + + auto input_mem = in_buffer.GetMKLDNNData(); + MKLDNNActForward &fwd = GetActForward(param_, ctx, in_buffer, *input_mem); auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer); fwd.SetNewMem(*input_mem, *out_mem_t.second); stream->RegisterPrim(fwd.GetFwd()); @@ -161,29 +213,28 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, } static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl( - const ActivationParam ¶m, const mkldnn::memory &input_mem, + const MKLDNNActParam ¶m, const mkldnn::memory &input_mem, const mkldnn::memory &diff_dst_memory, int dtype) { mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc(); mkldnn::memory::desc data_md = data_mpd.desc(); mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc(); auto cpu_engine = data_mpd.get_engine(); - auto alg = GetMKLDNNActAlgo(param); + auto alg = param.alg; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - DType alpha = 0; mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, - alg, data_md, alpha); + alg, data_md, param.slope); mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); - mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha); + mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, param.slope); mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, fw_pdesc); return bw_pdesc; }); LOG(FATAL) << "Unsupported data type for MKLDNN activation"; mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training, - alg, data_md, 0.0); + alg, data_md, param.slope); mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine); - mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0); + mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, param.slope); mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, fw_pdesc); return bw_pdesc; @@ -198,7 +249,7 @@ class MKLDNNActBackward { public: const mkldnn::eltwise_backward::primitive_desc pd; - explicit MKLDNNActBackward(const ActivationParam ¶m, const NDArray &data, + explicit MKLDNNActBackward(const MKLDNNActParam ¶m, const NDArray &data, const mkldnn::memory &mem, const mkldnn::memory &diff_dst_memory) : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {} @@ -229,7 +280,7 @@ class MKLDNNActBackward { const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; } }; -static inline MKLDNNActBackward &GetActBackward(const ActivationParam ¶m, +static inline MKLDNNActBackward &GetActBackward(const MKLDNNActParam ¶m, const OpContext &ctx, const NDArray &in_data, const NDArray &out_grad, @@ -269,6 +320,8 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx in_buffer = in_data.Reorder2Default(); const ActivationParam& param = nnvm::get(attrs.parsed); + MKLDNNActParam param_; + param_.alg = GetMKLDNNActAlgo(param); TmpMemMgr::Get()->Init(ctx.requested[activation::kTempSpace]); auto diff_dst_memory = out_buffer.GetMKLDNNData(); auto input_mem = in_buffer.GetMKLDNNData(); @@ -277,7 +330,7 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc()) input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc()); MKLDNNActBackward &bwd = - GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem); + GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem); MKLDNNStream *stream = MKLDNNStream::Get(); mkldnn_output_t diff_src_memory = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); @@ -287,6 +340,46 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx stream->Submit(); } +void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector& inputs, + const OpReqType &req, + const NDArray &output) { + if (req == kNullOp) { + return; + } + CHECK_GE(inputs.size(), 2U); + NDArray out_buffer = inputs[0]; + if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) + out_buffer = inputs[0].Reorder2Default(); + + NDArray in_buffer = inputs[1]; + if (inputs[1].IsView() && inputs[1].IsMKLDNNData()) + in_buffer = inputs[1].Reorder2Default(); + + const LeakyReLUParam& param = nnvm::get(attrs.parsed); + MKLDNNActParam param_; + param_.alg = GetMKLDNNActAlgo(param); + param_.slope = param.slope; + + TmpMemMgr::Get()->Init(ctx.requested[leakyrelu::kRandom]); + auto diff_dst_memory = out_buffer.GetMKLDNNData(); + auto input_mem = in_buffer.GetMKLDNNData(); + // We need to make sure the two inputs to eltwise_backward has the same memory + // descriptor. Otherwise, the perf will suffer. + if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc()) + input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc()); + MKLDNNActBackward &bwd = + GetActBackward(param_, ctx, in_buffer, out_buffer, *input_mem); + MKLDNNStream *stream = MKLDNNStream::Get(); + mkldnn_output_t diff_src_memory = + CreateMKLDNNMem(output, bwd.pd.diff_src_primitive_desc(), req); + bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second); + stream->RegisterPrim(bwd.GetBwd()); + CommitOutput(output, diff_src_memory); + stream->Submit(); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 34f4a0b0b062..9e8725e776e5 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -177,6 +177,7 @@ void *AlignMem(void *mem, size_t size, size_t alignment, size_t *space); namespace op { struct ActivationParam; +struct LeakyReLUParam; struct ConvolutionParam; struct DeconvolutionParam; struct SoftmaxParam; @@ -185,6 +186,8 @@ struct TransposeParam; struct ReshapeParam; bool SupportMKLDNNAct(const ActivationParam& param); bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input); +bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param); +bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param, const NDArray &input); bool SupportQuantizedMKLDNNAct(const ActivationParam ¶m); bool SupportMKLDNNConv(const ConvolutionParam ¶ms, const NDArray &input); bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input); diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h index 502abff6231b..b564a3318402 100644 --- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h @@ -109,9 +109,15 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const NDArray &out_grad, const NDArray &in_data, const OpReqType &req, const NDArray &in_grad); +void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data); +void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector& inputs, const OpReqType &req, + const NDArray &output); void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2, - const mkldnn::memory &out); + const mkldnn::memory &out); void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h index fcf767adebad..002b012bc35e 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h @@ -40,18 +40,18 @@ struct MKLDNNConvFusionParam { static inline bool IsOutputUInt8(const MKLDNNConvFusionParam& param) { bool result = false; const auto& mkldnn_param = param.full_conv_param.mkldnn_param; - auto IsOutputUInt8Helper = [](const mkldnn::algorithm& act_alg) { - return (act_alg == mkldnn::algorithm::eltwise_relu || - act_alg == mkldnn::algorithm::eltwise_logistic || - act_alg == mkldnn::algorithm::eltwise_soft_relu || - act_alg == mkldnn::algorithm::eltwise_bounded_relu); + auto IsOutputUInt8Helper = [](const MKLDNNPostEltwiseParam ¶m) { + return ((param.alg == mkldnn::algorithm::eltwise_relu && param.alpha == 0.f) || + param.alg == mkldnn::algorithm::eltwise_logistic || + param.alg == mkldnn::algorithm::eltwise_soft_relu || + param.alg == mkldnn::algorithm::eltwise_bounded_relu); }; if ((!mkldnn_param.with_sum) && mkldnn_param.with_act) { CHECK(param.full_conv_param.act_param.alg != mkldnn::algorithm::algorithm_undef); - result = IsOutputUInt8Helper(param.full_conv_param.act_param.alg); + result = IsOutputUInt8Helper(param.full_conv_param.act_param); } else if (mkldnn_param.with_postsum_act) { CHECK(param.full_conv_param.postsum_act_param.alg != mkldnn::algorithm::algorithm_undef); - result = IsOutputUInt8Helper(param.full_conv_param.postsum_act_param.alg); + result = IsOutputUInt8Helper(param.full_conv_param.postsum_act_param); } return result; } diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index d9bfa02b8820..886a21b44fc9 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -511,7 +511,7 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { } else if (node_name == "Convolution") { param_.full_conv_param.conv_param = nnvm::get(node->attrs.parsed); - } else if (node_name == "Activation" || node_name == "clip") { + } else if (node_name == "Activation" || node_name == "LeakyReLU" || node_name == "clip") { auto &post_act_param = (param_.full_conv_param.mkldnn_param.with_act && !with_act) ? param_.full_conv_param.act_param @@ -520,6 +520,10 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) { if (node_name == "Activation") { const auto act_param = nnvm::get(node->attrs.parsed); post_act_param.alg = GetMKLDNNActAlgo(act_param); + } else if (node_name == "LeakyReLU") { + const auto act_param = nnvm::get(node->attrs.parsed); + post_act_param.alpha = act_param.slope; + post_act_param.alg = GetMKLDNNActAlgo(act_param); } else { const auto clip_param = nnvm::get(node->attrs.parsed); post_act_param.alg = mkldnn::algorithm::eltwise_bounded_relu; diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h index 40b3f7c1d010..ff6589e6fb0a 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h @@ -24,6 +24,7 @@ #include #include #include "../../nn/activation-inl.h" +#include "../../leaky_relu-inl.h" #include "../../nn/convolution-inl.h" #include "../../nn/mkldnn/mkldnn_ops-inl.h" #include "../../tensor/matrix_op-inl.h" @@ -121,6 +122,15 @@ class SgMKLDNNConvSelector : public SubgraphSelector { status_ = kSuccess; return true; } + } else if ((!disable_conv_act_) && new_node.op()->name == "LeakyReLU") { + const LeakyReLUParam ¶m = + nnvm::get(new_node.attrs.parsed); + if (param.act_type == leakyrelu::kLeakyReLU) { + matched_list_.push_back(&new_node); + // not support conv+relu+sum yet. + status_ = kSuccess; + return true; + } } else if ((!disable_conv_act_) && new_node.op()->name == "clip") { if (!(quantize_ && (status_ == kSum))) { // TODO(zhennan): doesn't support int8 conv+sum+relu6 at moment. To support this, we @@ -208,7 +218,7 @@ class SgMKLDNNConvProperty : public SubgraphProperty { n->attrs.dict["with_sum"] = "true"; _with_sum = true; - } else if (sub_name == "Activation" || sub_name == "clip") { + } else if (sub_name == "Activation" || sub_name == "LeakyReLU" || sub_name == "clip") { node_name << "act_"; if (!_with_sum) { n->attrs.dict["with_act"] = "true"; diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 8ba6d6183c86..dc494b25b549 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -327,6 +327,8 @@ def conv_act(no_bias, data_shape, alg): kernel=(3, 3), stride=(1, 1), no_bias=no_bias) if alg == "relu6": relu = mx.symbol.clip(data=conv, name='relu6', a_min=0, a_max=6) + elif alg == "leakyrelu": + relu = mx.symbol.LeakyReLU(data=conv, slope=0.25, act_type='leaky') else: relu = mx.symbol.Activation(data=conv, name=alg, act_type=alg) return relu, attr @@ -339,6 +341,8 @@ def conv_act_sum(no_bias, data_shape, alg): kernel=(3, 3), stride=(1, 1), no_bias=no_bias) if alg == "relu6": relu = mx.symbol.clip(data=conv, name='relu6', a_min=0, a_max=6) + elif alg == "leakyrelu": + relu = mx.symbol.LeakyReLU(data=conv, slope=0.25, act_type='leaky') else: relu = mx.symbol.Activation(data=conv, name=alg, act_type=alg) conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, @@ -379,6 +383,8 @@ def conv_bn_act(no_bias, data_shape, alg): bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") if alg == "relu6": relu = mx.symbol.clip(data=bn1, name='relu6', a_min=0, a_max=6) + elif alg == "leakyrelu": + relu = mx.symbol.LeakyReLU(data=bn1, slope=0.25, act_type='leaky') else: relu = mx.symbol.Activation(data=bn1, name=alg, act_type=alg) return relu, attr @@ -395,6 +401,8 @@ def conv_bn_sum_act(no_bias, data_shape, alg): sum1 = bn1 + conv1 if alg == "relu6": relu = mx.symbol.clip(data=sum1, name='relu6', a_min=0, a_max=6) + elif alg == "leakyrelu": + relu = mx.symbol.LeakyReLU(data=sum1, slope=0.25, act_type='leaky') else: relu = mx.symbol.Activation(data=sum1, name=alg, act_type=alg) return relu, attr @@ -693,7 +701,8 @@ def test_pos_conv_act(): "sigmoid": True, "tanh": True, "softrelu": True, - "relu6": True} + "relu6": True, + "leakyrelu": True} for data_shape in DATA_SHAPE: for (alg, quantize) in act_list.items(): net, attrs = conv_act(False, data_shape, alg) @@ -731,7 +740,8 @@ def test_pos_conv_bn_act(): "sigmoid": True, "tanh": True, "softrelu": True, - "relu6": True} + "relu6": True, + "leakyrelu": True} for data_shape in DATA_SHAPE: for (alg, quantize) in act_list.items(): net, attrs = conv_bn_act(False, data_shape, alg) @@ -745,7 +755,8 @@ def test_pos_conv_bn_sum_act(): "sigmoid": True, "tanh": True, "softrelu": True, - "relu6": False} + "relu6": False, + "leakyrelu": True} for data_shape in DATA_SHAPE: for (alg, quantize) in act_list.items(): net, attrs = conv_bn_sum_act(False, data_shape, alg) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 46e976432fa8..d11443ca839a 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1223,7 +1223,7 @@ def elu(x): return [elu(x_i) for x_i in x] for test_point, ref_point in zip(elu_test(point_to_validate), elu(point_to_validate)): - assert test_point == ref_point + assert_almost_equal(test_point.asnumpy(), ref_point.asnumpy()) selu = mx.gluon.nn.SELU() def selu_test(x):