diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index c026d1f6853b..1348b4efc81b 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -30,7 +30,7 @@
 #if MXNET_USE_MKLDNN == 1
 #include "./mkldnn/mkldnn_base-inl.h"
 #include "./mkldnn/mkldnn_ops-inl.h"
-#endif  // MXNET_USE_MKLDNN
+#endif  // MXNET_USE_MKLDNN == 1
 #include "../operator_common.h"
 #include "../../common/utils.h"
 
@@ -40,49 +40,22 @@ namespace op {
 namespace activation {
 
 int GradNumInputs(int act_type) {
-#if MXNET_USE_CUDNN == 1
   // check activation.cu \sa ActivationGradCompute
-  switch (act_type) {
-    case kReLU:
-    case kSoftReLU:
-      return 2;
-    case kSoftSign:
-    case kTanh:
-    case kSigmoid:
-      return 3;
-    default:
-      CHECK(false) << "missing activation type";
-  }
-#elif MXNET_USE_MKLDNN == 1
-  // \sa ActivationGradComputeExCPU
   switch (act_type) {
     case kReLU:
       return 2;
-    case kSigmoid:
-    case kTanh:
     case kSoftReLU:
     case kSoftSign:
-      return 3;
-    default:
-      CHECK(false) << "missing activation type";
-  }
-#else
-  // check activation-inl.h \sa ActivationGradComputeImpl
-  switch (act_type) {
-    case kReLU:
-    case kSigmoid:
     case kTanh:
-    case kSoftReLU:
-      return 2;
-    case kSoftSign:
+    case kSigmoid:
       return 3;
     default:
       CHECK(false) << "missing activation type";
   }
-#endif
   // unreachable
   return -1;
 }
+
 }  // namespace activation
 
 DMLC_REGISTER_PARAMETER(ActivationParam);
@@ -99,52 +72,21 @@ struct ActivationGrad {
     const NodeAttrs& attrs = n->attrs;
     using namespace activation;
     int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
-#if MXNET_USE_CUDNN == 1
     // for ReLU, no need to pass input data. This enables inplace optimization during the
     // forward pass.
     // check activation.cu \sa ActivationGradCompute
     switch (act_type) {
       case kReLU:
-      case kSoftReLU:
         break;
-      case kSoftSign:
-      case kTanh:
-      case kSigmoid:
-        heads.push_back(n->inputs[activation::kData]);
-        break;
-      default:
-        CHECK(false) << "missing activation type";
-    }
-#elif MXNET_USE_MKLDNN == 1
-    // \sa ActivationGradComputeExCPU
-    switch (act_type) {
-      case kReLU:
-        break;
-      case kSoftSign:
-      case kTanh:
       case kSoftReLU:
-      case kSigmoid:
-        heads.push_back(n->inputs[activation::kData]);
-        break;
-      default:
-        CHECK(false) << "missing activation type";
-    }
-
-#else
-    // check activation-inl.h \sa ActivationGradComputeImpl
-    switch (act_type) {
       case kSoftSign:
-        heads.push_back(n->inputs[activation::kData]);
-        break;
-      case kReLU:
       case kTanh:
-      case kSoftReLU:
       case kSigmoid:
+        heads.push_back(n->inputs[activation::kData]);
         break;
       default:
         CHECK(false) << "missing activation type";
     }
-#endif
     return MakeGradNode(op_name, n, heads, n->attrs.dict);
   }
 };
@@ -177,9 +119,9 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     // XXX: for y = relu(x), y is passed as "in_data" to Backward()
     const bool relu = param.act_type == activation::kReLU;
-    MKLDNNActivationBackward(attrs, ctx, inputs[0], relu ? inputs[1] : inputs[2], req[0],
+    MKLDNNActivationBackward(attrs, ctx, inputs.at(0), relu ? inputs.at(1) : inputs.at(2), req[0],
                              outputs[0]);
-    MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
   FallBackCompute(ActivationGradComputeImpl<cpu>, attrs, ctx, inputs, req, outputs);
@@ -209,7 +151,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
   return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
                            dispatch_mode, in_attrs, out_attrs);
 }
-#endif
+#endif  // MXNET_USE_MKLDNN == 1
 
 
 MXNET_OPERATOR_REGISTER_UNARY(Activation)
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index 13330e3a95f3..ec7db844b100 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -85,20 +85,20 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
   // both SoftReLU and SoftSign not supported by CUDNN yet
   if (act_type == activation::kSoftReLU) {
     ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
-        ctx, inputs[0], inputs[1], req[0], outputs[0]);
+        ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
   } else if (act_type == activation::kSoftSign) {
     ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
-        ctx, inputs[0], inputs[2], req[0], outputs[0]);
+        ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
   } else if (act_type == activation::kReLU) {
-    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, {
       // XXX: for y = relu(x), y is passed as "in_data" to Backward()
-      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[1],
-                                          inputs[1], req[0], outputs[0]);
+      get_cudnn_op<DType>(param).Backward(ctx, inputs.at(0), inputs.at(1),
+                                          inputs.at(1), req[0], outputs[0]);
     });
   } else {
-    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[2],
-                                          inputs[1], req[0], outputs[0]);
+    MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, {
+      get_cudnn_op<DType>(param).Backward(ctx, inputs.at(0), inputs.at(2),
+                                          inputs.at(1), req[0], outputs[0]);
     });
   }
 }
diff --git a/tests/cpp/operator/activation_perf.cc b/tests/cpp/operator/activation_perf.cc
index 05863d7a5f8d..bba8a3ec5722 100644
--- a/tests/cpp/operator/activation_perf.cc
+++ b/tests/cpp/operator/activation_perf.cc
@@ -53,6 +53,12 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
     runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
             activation_args, "Activation", "_backward_Activation"), 1);
   }
+  for (const string& activation : activations) {
+    kwargs_t activation_args = {{"act_type", activation}};
+    test::op::CoreOperatorRunner<float> runner;
+    runner.RunBidirectional(true, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
+            activation_args, "Activation", "_backward_Activation"), 1);
+  }
 }
 
 /*!
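
Note (reviewer sketch, not part of the patch): after this change every backend agrees on one input layout for _backward_Activation — [out_grad, out_data] for ReLU and [out_grad, out_data, in_data] for the other activation types — which is what the unified GradNumInputs()/ActivationGrad switches encode and what the inputs.at(0)/inputs.at(1)/inputs.at(2) accesses above assume. The standalone C++ sketch below mirrors that convention without any MXNet dependencies; the activation_sketch namespace, the enum ordering, and the main() driver are illustrative assumptions, not code from the repository.

// Standalone illustration of the unified gradient-input convention; compiles
// without MXNet headers.
#include <cassert>
#include <iostream>

namespace activation_sketch {

enum ActType { kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign };

// Mirrors the patched activation::GradNumInputs(): ReLU can be differentiated
// from its output alone, so its backward node takes only [out_grad, out_data];
// the remaining activations also receive the forward input,
// [out_grad, out_data, in_data].
int GradNumInputs(int act_type) {
  switch (act_type) {
    case kReLU:
      return 2;
    case kSoftReLU:
    case kSoftSign:
    case kTanh:
    case kSigmoid:
      return 3;
    default:
      assert(false && "missing activation type");
  }
  return -1;  // unreachable
}

}  // namespace activation_sketch

int main() {
  using namespace activation_sketch;
  std::cout << GradNumInputs(kReLU) << std::endl;     // prints 2
  std::cout << GradNumInputs(kSigmoid) << std::endl;  // prints 3
  return 0;
}

Running the sketch prints 2 and 3, matching the counts returned by the patched GradNumInputs() and the branches taken in ActivationGradComputeExCPU and the CUDA/CUDNN backward path.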