diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 01d1b8c6bb9e..fc2bb96e865f 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -316,30 +316,12 @@ static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1); const SoftmaxParam& param = nnvm::get(attrs.parsed); - int arg_dtype = param.dtype.has_value() ? param.dtype.value() : -1; - int in_dtype = (*in_attrs)[0]; - int out_dtype = (*out_attrs)[0]; - - if (out_dtype != -1 && in_dtype != -1) { - TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); - return true; - } else if (in_dtype != -1) { - if (arg_dtype != -1) { - TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); - } else { - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype); - } - return true; - } else if (out_dtype != -1) { - TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype); + if (softmax_has_dtype_override(attrs)) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value()); + type_assign(&(*in_attrs)[0], (*out_attrs)[0]); return true; } else { - if (arg_dtype != -1) { - TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype); - } - return false; + return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs); } }