diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h
index cf44da699156..faac69269f00 100644
--- a/src/operator/elemwise_op_common.h
+++ b/src/operator/elemwise_op_common.h
@@ -128,29 +128,33 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
   if (n_out != -1)
     out_size = static_cast<size_t>(n_out);
 
-  auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
+  CHECK_LE(in_size, in_attrs->size());
+  CHECK_LE(out_size, out_attrs->size());
+  auto deduce = [&](const std::vector<AttrType>& vec, size_t size, const char *name) {
       for (size_t i = 0; i < size; ++i) {
-        CHECK(assign(&dattr, (*vec)[i]))
+        CHECK(assign(&dattr, vec.at(i)))
           << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
           << name << ": " << "expected " << attr_string(dattr)
-          << ", got " << attr_string((*vec)[i]);
+          << ", got " << attr_string(vec.at(i));
       }
     };
-  deduce(in_attrs, in_size, "input");
-  if (reverse_infer) deduce(out_attrs, out_size, "output");
+  deduce(*in_attrs, in_size, "input");
+  if (reverse_infer)
+    deduce(*out_attrs, out_size, "output");
 
   auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
       for (size_t i = 0; i < size; ++i) {
-        CHECK(assign(&(*vec)[i], dattr))
+        CHECK(assign(&(vec->at(i)), dattr))
           << "Incompatible attr in node " << attrs.name << " at " << i << "-th "
           << name << ": " << "expected " << attr_string(dattr)
-          << ", got " << attr_string((*vec)[i]);
+          << ", got " << attr_string(vec->at(i));
       }
     };
   write(in_attrs, in_size, "input");
   write(out_attrs, out_size, "output");
 
-  if (is_none(dattr)) return false;
+  if (is_none(dattr))
+    return false;
   return true;
 }
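Illustration (not part of the patch): the ElemwiseAttr hunk above adds CHECK_LE guards so that a fixed n_in/n_out from the template arguments can never index past the attribute vectors actually passed in. The standalone sketch below mirrors the deduce/write flow with those guards; it uses plain int attributes with -1 meaning "unknown", assumes reverse_infer is true, and every name in it is illustrative only.

// Minimal stand-in for the deduce/write pattern of ElemwiseAttr, with the
// newly added bounds checks mirrored by asserts.
#include <cassert>
#include <iostream>
#include <vector>

// Merge src into *dst; fail only if both are known and disagree.
bool assign(int* dst, int src) {
  if (src == -1) return true;
  if (*dst == -1) { *dst = src; return true; }
  return *dst == src;
}

bool ElemwiseAttrSketch(std::vector<int>* in_attrs, std::vector<int>* out_attrs,
                        size_t in_size, size_t out_size) {
  int dattr = -1;
  assert(in_size <= in_attrs->size());    // mirrors CHECK_LE(in_size, in_attrs->size())
  assert(out_size <= out_attrs->size());  // mirrors CHECK_LE(out_size, out_attrs->size())
  auto deduce = [&](const std::vector<int>& vec, size_t size) {
    for (size_t i = 0; i < size; ++i) {
      const bool ok = assign(&dattr, vec.at(i));
      assert(ok && "incompatible attr");
      (void)ok;
    }
  };
  deduce(*in_attrs, in_size);
  deduce(*out_attrs, out_size);           // reverse_infer assumed true
  auto write = [&](std::vector<int>* vec, size_t size) {
    for (size_t i = 0; i < size; ++i) {
      const bool ok = assign(&vec->at(i), dattr);
      assert(ok && "incompatible attr");
      (void)ok;
    }
  };
  write(in_attrs, in_size);
  write(out_attrs, out_size);
  return dattr != -1;                     // false when nothing could be inferred
}

int main() {
  std::vector<int> in = {3, -1}, out = {-1};
  std::cout << ElemwiseAttrSketch(&in, &out, in.size(), out.size())  // prints 1
            << " " << out[0] << "\n";                                // prints 3
}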
diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index 2705177f951d..ab7d4b7c17b5 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -48,6 +48,9 @@ enum ActivationOpInputs {kData};
 enum ActivationOpOutputs {kOut};
 enum ActivationOpResource {kTempSpace};
 enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};
+
+// Get the number of inputs to the gradient depending on the activation type
+int ActivationGradNumInputs(int act_type);
 }  // activation
 
 struct ActivationParam : public dmlc::Parameter<ActivationParam> {
@@ -199,13 +202,8 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
-  bool relu = param.act_type == activation::kReLU;
-  CHECK_EQ(inputs.size(), relu ? 2U : 3U);
-#else
-  bool softsign = param.act_type == activation::kSoftSign;
-  CHECK_EQ(inputs.size(), softsign ? 3U : 2U);
-#endif
+  const int act_type = param.act_type;
+  CHECK_EQ(inputs.size(), activation::ActivationGradNumInputs(act_type));
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
   ActivationGradComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index ba44ebd4ed4d..147c92539a61 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -37,6 +37,54 @@ namespace mxnet {
 namespace op {
 
+namespace activation {
+
+int ActivationGradNumInputs(int act_type) {
+#if (MXNET_USE_CUDNN == 1)
+  // check activation.cu \sa ActivationGradCompute
+  switch (act_type) {
+    case kReLU:
+    case kSoftReLU:
+      return 2;
+    case kSoftSign:
+    case kTanh:
+    case kSigmoid:
+      return 3;
+    default:
+      CHECK(false) << "missing activation type";
+  }
+#elif (MXNET_USE_MKLDNN == 1)
+  // \sa ActivationGradComputeExCPU
+  switch (act_type) {
+    case kReLU:
+      return 2;
+    case kSigmoid:
+    case kTanh:
+    case kSoftReLU:
+    case kSoftSign:
+      return 3;
+    default:
+      CHECK(false) << "missing activation type";
+  }
+#else
+  // check activation-inl.h \sa ActivationGradComputeImpl
+  switch (act_type) {
+    case kReLU:
+    case kSigmoid:
+    case kTanh:
+    case kSoftReLU:
+      return 2;
+    case kSoftSign:
+      return 3;
+    default:
+      CHECK(false) << "missing activation type";
+  }
+#endif
+  // unreachable
+  return -1;
+}
+}  // namespace activation
+
 DMLC_REGISTER_PARAMETER(ActivationParam);
 
 // This will determine the order of the inputs for backward computation.
@@ -48,18 +96,52 @@ struct ActivationGrad {
     heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});
 
     const NodeAttrs& attrs = n->attrs;
+    using namespace activation;
     int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
-    if (act_type == activation::kSoftSign) {
-      // for softsign need the inputs to compute the activation.
-      heads.push_back(n->inputs[activation::kData]);
-    }
-
-#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
+#if (MXNET_USE_CUDNN == 1)
     // for ReLU, no need to pass input data. This enables inplace optimization during the
     // forward pass.
-    if (act_type != activation::kReLU &&
-        act_type != activation::kSoftSign) {
-      heads.push_back(n->inputs[activation::kData]);
+    // check activation.cu \sa ActivationGradCompute
+    switch (act_type) {
+      case kReLU:
+      case kSoftReLU:
+        break;
+      case kSoftSign:
+      case kTanh:
+      case kSigmoid:
+        heads.push_back(n->inputs[activation::kData]);
+        break;
+      default:
+        CHECK(false) << "missing activation type";
+    }
+#elif (MXNET_USE_MKLDNN == 1)
+    // \sa ActivationGradComputeExCPU
+    switch (act_type) {
+      case kReLU:
+        break;
+      case kSoftSign:
+      case kTanh:
+      case kSoftReLU:
+      case kSigmoid:
+        heads.push_back(n->inputs[activation::kData]);
+        break;
+      default:
+        CHECK(false) << "missing activation type";
+    }
+
+#else
+    // check activation-inl.h \sa ActivationGradComputeImpl
+    switch (act_type) {
+      case kSoftSign:
+        heads.push_back(n->inputs[activation::kData]);
+        break;
+      case kReLU:
+      case kTanh:
+      case kSoftReLU:
+      case kSigmoid:
+        break;
+      default:
+        CHECK(false) << "missing activation type";
     }
 #endif
     return MakeGradNode(op_name, n, heads, n->attrs.dict);
@@ -133,6 +215,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
 }
 #endif
 
+
 MXNET_OPERATOR_REGISTER_UNARY(Activation)
 .describe(R"code(Applies an activation function element-wise to the input.
@@ -163,18 +246,16 @@ The following activation functions are supported:
 
 NNVM_REGISTER_OP(_backward_Activation)
 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
-    int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
-    // for ReLU activation, the backward pass only needs ograd and output
-    if (act_type == activation::kReLU) return 2;
-    return 3;
-  })
+  const int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
+  return activation::ActivationGradNumInputs(act_type);
+})
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
 #endif
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}};
 })
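Illustration (not part of the patch): ActivationGradNumInputs() and the switches in ActivationGrad must stay in sync, since both encode "ograd plus the forward output, plus the forward input only when the backward formula needs x". The standalone consistency sketch below covers only the plain-CPU (#else) branch; the two helper functions are illustrative stand-ins, not MXNet code.

#include <cassert>

enum ActType { kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign };

// Mirrors the #else branch of ActivationGradNumInputs in activation.cc.
int GradNumInputsPlainCpu(int act_type) {
  return act_type == kSoftSign ? 3 : 2;
}

// Mirrors the #else branch of ActivationGrad::operator(): ograd and the
// forward output are always passed; softsign additionally needs the input.
int GradHeadsPlainCpu(int act_type) {
  int heads = 2;                        // ograd + output
  if (act_type == kSoftSign) ++heads;   // + input data
  return heads;
}

int main() {
  const int types[] = {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};
  for (int t : types)
    assert(GradNumInputsPlainCpu(t) == GradHeadsPlainCpu(t));
  return 0;
}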
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index 8892cc34f710..94d358e416d0 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -88,10 +88,15 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
   } else if (param.act_type == activation::kSoftSign) {
     ActivationBackward(
         ctx, inputs[0], inputs[2], req[0], outputs[0]);
-  } else {
+  } else if (param.act_type == activation::kReLU) {
     MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       // XXX: for y = relu(x), y is passed as "in_data" to Backward()
-      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], relu ? inputs[1] : inputs[2],
+      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[1],
+                                          inputs[1], req[0], outputs[0]);
+    });
+  } else {
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[2],
                                           inputs[1], req[0], outputs[0]);
     });
   }
diff --git a/tests/cpp/operator/activation_perf.cc b/tests/cpp/operator/activation_perf.cc
index 94366fd869d3..05863d7a5f8d 100644
--- a/tests/cpp/operator/activation_perf.cc
+++ b/tests/cpp/operator/activation_perf.cc
@@ -47,7 +47,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
     "softrelu",
     "softsign"
   };
-  for(const string& activation : activations) {
+  for (const string& activation : activations) {
     kwargs_t activation_args = {{"act_type", activation}};
     test::op::CoreOperatorRunner<float> runner;
     runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
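Illustration (not part of the patch): a companion unit test in the spirit of activation_perf.cc could pin down the two-or-three-inputs contract for every activation type. This is a hypothetical sketch; the gtest usage follows the existing cpp tests, but the include path and the test name are assumptions, and it would need to be built and linked the same way as the other operator tests.

// Hypothetical companion test. Assumes gtest, a workable relative include
// path from tests/cpp/operator, and that mxnet::op::activation::
// ActivationGradNumInputs (defined in activation.cc) is linked in.
#include <gtest/gtest.h>
#include "../../src/operator/nn/activation-inl.h"  // assumed include path

TEST(ACTIVATION, GradNumInputsIsTwoOrThree) {
  using namespace mxnet::op::activation;
  const int types[] = {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};
  for (int t : types) {
    const int n = ActivationGradNumInputs(t);
    EXPECT_GE(n, 2);  // always ograd + forward output
    EXPECT_LE(n, 3);  // at most one extra input (the forward data)
  }
}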