diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 096d87416081..8098af23f2e9 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -64,7 +64,8 @@ struct log_softmax_fwd {
 };
 
 
-template<typename OP, bool negate, typename AType, typename DType, typename OType, int ndim>
+template<typename OP, bool negate, typename AType, typename DType,
+         typename OType, int ndim>
 inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
                     Shape<ndim> shape, int axis, const DType temperature) {
   index_t M = shape[axis];
@@ -310,9 +311,9 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   }
 };
 
-static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
+static inline int softmax_dtype_param(const nnvm::NodeAttrs& attrs) {
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  return param.dtype.has_value() && param.dtype.value() != -1;
+  return param.dtype.has_value() ? param.dtype.value() : -1;
 }
 
 static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
@@ -322,7 +323,7 @@ static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
 
-  if (softmax_has_dtype_override(attrs)) {
+  if (softmax_dtype_param(attrs) != -1) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
     type_assign(&(*in_attrs)[0], (*out_attrs)[0]);
     return true;
@@ -334,7 +335,7 @@ static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
 static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
                                       mxnet::ShapeVector *in_attrs,
                                       mxnet::ShapeVector *out_attrs) {
-  if (softmax_has_dtype_override(attrs)) {
+  if (softmax_dtype_param(attrs) != -1) {
     return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
   } else {
     return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
@@ -345,7 +346,7 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
                                      std::vector<int>* in_attrs,
                                      std::vector<int>* out_attrs) {
   CHECK_EQ(out_attrs->size(), 1);
-  if (softmax_has_dtype_override(attrs)) {
+  if (softmax_dtype_param(attrs) != -1) {
     CHECK_EQ(in_attrs->size(), 3);
     int in_dtype = (*in_attrs)[1];
     int out_dtype = (*in_attrs)[2];
@@ -365,7 +366,7 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
 
 static inline std::vector<std::pair<int, int> >
 SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
-  if (softmax_has_dtype_override(attrs)) {
+  if (softmax_dtype_param(attrs) != -1) {
     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
   } else {
     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
@@ -373,11 +374,11 @@ SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
 }
 
 static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
-  return softmax_has_dtype_override(attrs) ? 3 : 2;
+  return softmax_dtype_param(attrs) != -1 ? 3 : 2;
 }
 
 static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
-  if (softmax_has_dtype_override(attrs)) {
+  if (softmax_dtype_param(attrs) != -1) {
     return std::vector<std::string>{"ograd", "data", "output"};
   } else {
     return std::vector<std::string>{"ograd", "output"};
@@ -388,7 +389,7 @@ struct SoftmaxFGradient {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                           const std::vector<nnvm::NodeEntry>& ograds) const {
-    if (softmax_has_dtype_override(n->attrs)) {
+    if (softmax_dtype_param(n->attrs) != -1) {
       return ElemwiseGradUseInOut {op_name}(n, ograds);
     } else {
       return ElemwiseGradUseOut {op_name}(n, ograds);
@@ -410,19 +411,25 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
   const double temperature = param.temperature.has_value() ?
     param.temperature.value() : 1.0;
   mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
-  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
-      if (shape.ndim() == 2) {
-        Softmax<OP, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-            outputs[0].dptr<OType>(), shape.get<2>(), axis,
-            static_cast<DType>(temperature));
-      } else {
-        Softmax<OP, negate, AType>(
-            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
-            outputs[0].dptr<OType>(), shape.get<3>(), axis,
-            static_cast<DType>(temperature));
-      }
+
+  int atype_flag_ = softmax_dtype_param(attrs);
+  atype_flag_ = atype_flag_ != -1 ? atype_flag_ : inputs[0].type_flag_;
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MSHADOW_REAL_TYPE_SWITCH(atype_flag_, AType, {
+      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+        if (shape.ndim() == 2) {
+          Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<2>(), axis,
+              static_cast<DType>(temperature));
+        } else {
+          Softmax<OP, negate, AType>(
+              ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
+              outputs[0].dptr<OType>(), shape.get<3>(), axis,
+              static_cast<DType>(temperature));
+        }
+      });
     });
   });
 }
@@ -442,23 +449,28 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
     param.temperature.value() : 1.0;
   mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
 
-  int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
-
-  MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        if (shape.ndim() == 2) {
-          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-              shape.get<2>(), axis, static_cast<DType>(temperature));
-        } else {
-          SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-              ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-              inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-              shape.get<3>(), axis, static_cast<DType>(temperature));
-        }
-      });
+  int out_idx = softmax_dtype_param(attrs) != -1 ? 2 : 1;
+
+  int atype_flag_ = softmax_dtype_param(attrs);
+  atype_flag_ = atype_flag_ != -1 ? atype_flag_ : inputs[0].type_flag_;
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, OType, {
+    MSHADOW_REAL_TYPE_SWITCH(atype_flag_, AType, {
+      MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+          if (shape.ndim() == 2) {
+            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<2>(), axis, static_cast<DType>(temperature));
+          } else {
+            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                shape.get<3>(), axis, static_cast<DType>(temperature));
+          }
+        });
+      });
     });
   });
 }
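
For readers skimming the patch, the stand-alone sketch below (assumed names and a simplified 1-D layout, not MXNet code) illustrates why the accumulation type is kept separate from the input and output types: the reductions inside softmax run in AType, which this patch now takes from the operator's dtype parameter when it is set instead of deriving it from the input type alone.

// Minimal, self-contained sketch (illustrative only, not MXNet's actual code)
// of the idea behind the patch: softmax is templated separately on the input
// type (DType), the accumulation type (AType), and the output type (OType),
// so a low-precision input can be reduced in higher precision.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

template <typename DType, typename AType, typename OType>
void softmax_1d(const DType* in, OType* out, int n, double temperature) {
  // Max-subtraction for numerical stability, tracked in the accumulation type.
  AType mmax = static_cast<AType>(in[0]);
  for (int i = 1; i < n; ++i) {
    mmax = std::max(mmax, static_cast<AType>(in[i]));
  }
  // Accumulate the normalizer in AType rather than DType; this is what the
  // dtype/accumulation-type switch in the patch is meant to control.
  AType sum = AType(0);
  const AType t = static_cast<AType>(temperature);
  for (int i = 0; i < n; ++i) {
    sum += std::exp((static_cast<AType>(in[i]) - mmax) / t);
  }
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<OType>(std::exp((static_cast<AType>(in[i]) - mmax) / t) / sum);
  }
}

int main() {
  // float input with double accumulation, standing in for the
  // float16-input / float32-accumulation case the dtype parameter targets.
  std::vector<float> in = {1.0f, 2.0f, 3.0f};
  std::vector<float> out(in.size());
  softmax_1d<float, double, float>(in.data(), out.data(),
                                   static_cast<int>(in.size()), 1.0);
  for (float v : out) std::printf("%f\n", v);
  return 0;
}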