diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 91284f205f86..90950bc9e92e 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -26,6 +26,7 @@
 #define MXNET_OPERATOR_NN_SOFTMAX_INL_H_
 
 #include <algorithm>
+#include <string>
 #include <utility>
 #include <vector>
 
@@ -343,7 +344,9 @@ static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
 static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
                                      std::vector<int>* in_attrs,
                                      std::vector<int>* out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1);
   if (softmax_has_dtype_override(attrs)) {
+    CHECK_EQ(in_attrs->size(), 3);
     int in_dtype = (*in_attrs)[1];
     int out_dtype = (*in_attrs)[2];
     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
@@ -351,6 +354,7 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
 
     return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
   } else {
+    CHECK_EQ(in_attrs->size(), 2);
     int out_dtype = (*in_attrs)[1];
     TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
@@ -368,6 +372,18 @@ SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
   }
 }
 
+static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
+  return softmax_has_dtype_override(attrs) ? 3 : 2;
+}
+
+static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    return std::vector<std::string>{"ograd", "data", "output"};
+  } else {
+    return std::vector<std::string>{"ograd", "output"};
+  }
+}
+
 struct SoftmaxFGradient {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index f300d56dcda2..c88f738c356d 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -115,18 +115,13 @@ Example::
 .add_arguments(SoftmaxParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_softmax)
-.set_num_inputs(3)
+.set_num_inputs(SoftmaxGradOpNumInputs)
 .set_num_outputs(1)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"ograd", "data", "output"};
-  })
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
 .set_attr<mxnet::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
-.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
-.add_argument("data", "NDArray-or-Symbol", "input")
-.add_argument("output", "NDArray-or-Symbol", "output")
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                    mxnet_op::softmax_bwd>);
@@ -175,18 +170,13 @@ Example::
 .add_arguments(SoftmaxParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_softmin)
-.set_num_inputs(3)
+.set_num_inputs(SoftmaxGradOpNumInputs)
 .set_num_outputs(1)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"ograd", "data", "output"};
-  })
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
 .set_attr<mxnet::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
-.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
-.add_argument("data", "NDArray-or-Symbol", "input")
-.add_argument("output", "NDArray-or-Symbol", "output")
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                    mxnet_op::softmax_bwd, true>);
@@ -223,18 +213,13 @@ Examples::
 .add_arguments(SoftmaxParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_log_softmax)
-.set_num_inputs(3)
+.set_num_inputs(SoftmaxGradOpNumInputs)
 .set_num_outputs(1)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"ograd", "data", "output"};
-  })
+.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
 .set_attr<mxnet::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
-.add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
-.add_argument("data", "NDArray-or-Symbol", "input")
-.add_argument("output", "NDArray-or-Symbol", "output")
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute", SoftmaxGradCompute<cpu, op::mshadow_op::left,
                                                    mxnet_op::log_softmax_bwd>);