diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index b663e7de4698..94169255210f 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -308,16 +308,16 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   }
 };
 
-inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
-                          std::vector<int>* in_attrs,
-                          std::vector<int>* out_attrs) {
+static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
+                                 std::vector<int>* in_attrs,
+                                 std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
 
-  int arg_dtype = param.dtype.has_value()?param.dtype.value():-1,
-      in_dtype = (*in_attrs)[0],
-      out_dtype = (*out_attrs)[0];
+  int arg_dtype = param.dtype.has_value() ? param.dtype.value() : -1;
+  int in_dtype = (*in_attrs)[0];
+  int out_dtype = (*out_attrs)[0];
 
   if (out_dtype != -1 && in_dtype != -1) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
@@ -342,20 +342,61 @@ inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
   }
 }
 
-inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
-                              std::vector<int>* in_attrs,
-                              std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 3);
-  CHECK_EQ(out_attrs->size(), 1);
+static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  return param.dtype.has_value() && param.dtype.value() != -1;
+}
+
+static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
+                                      std::vector<TShape> *in_attrs,
+                                      std::vector<TShape> *out_attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+  } else {
+    return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
+  }
+}
+
+static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
+                                     std::vector<int>* in_attrs,
+                                     std::vector<int>* out_attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    int in_dtype = (*in_attrs)[1];
+    int out_dtype = (*in_attrs)[2];
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
 
-  int in_dtype = (*in_attrs)[1],
-      out_dtype = (*in_attrs)[2];
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
+    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+  } else {
+    int out_dtype = (*in_attrs)[1];
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
 
-  return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+  }
 }
 
+static inline std::vector<std::pair<int, int> >
+SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
+  if (softmax_has_dtype_override(attrs)) {
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+  } else {
+    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
+  }
+}
+
+struct SoftmaxFGradient {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    if (softmax_has_dtype_override(n->attrs)) {
+      return ElemwiseGradUseInOut {op_name}(n, ograds);
+    } else {
+      return ElemwiseGradUseOut {op_name}(n, ograds);
+    }
+  }
+};
+
 template<typename xpu, typename OP, bool negate = false>
 void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
@@ -401,17 +442,20 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   const double temperature = param.temperature.has_value() ?
     param.temperature.value() : 1.0;
   TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
-  MXNET_REAL_ACC_TYPE_SWITCH(inputs[2].type_flag_, OType, AType, {
+
+  int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+
+  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         if (shape.ndim() == 2) {
           SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[2].dptr<OType>(),
+              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
               inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
               shape.get<2>(), axis, static_cast<DType>(temperature));
         } else {
           SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[2].dptr<OType>(),
+              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
               inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
               shape.get<3>(), axis, static_cast<DType>(temperature));
         }
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 1d6cef58263c..f300d56dcda2 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -102,7 +102,7 @@ Example::
 .set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
 .set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -121,12 +121,9 @@ NNVM_REGISTER_OP(_backward_softmax)
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"ograd", "data", "output"};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
-  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
 .add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
 .add_argument("data", "NDArray-or-Symbol", "input")
 .add_argument("output", "NDArray-or-Symbol", "output")
@@ -165,7 +162,7 @@ Example::
     return std::vector<std::string>{"output"};
   })
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd, true>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_softmin"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmin"})
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -184,12 +181,9 @@ NNVM_REGISTER_OP(_backward_softmin)
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"ograd", "data", "output"};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
-  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
 .add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
 .add_argument("data", "NDArray-or-Symbol", "input")
 .add_argument("output", "NDArray-or-Symbol", "output")
@@ -216,7 +210,7 @@ Examples::
 )code")
 .set_attr_parser(ParamParser<SoftmaxParam>)
 .set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_log_softmax"})
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_log_softmax"})
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -235,12 +229,9 @@ NNVM_REGISTER_OP(_backward_log_softmax)
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"ograd", "data", "output"};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-  [](const NodeAttrs& attrs){
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
-  })
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
 .add_argument("ograd", "NDArray-or-Symbol", "gradient of output")
 .add_argument("data", "NDArray-or-Symbol", "input")
 .add_argument("output", "NDArray-or-Symbol", "output")