Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
simplify infer type
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Feb 13, 2019
1 parent 704b0e7 commit bfced75
Showing 1 changed file with 4 additions and 22 deletions.
26 changes: 4 additions & 22 deletions src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,30 +316,12 @@ static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);

int arg_dtype = param.dtype.has_value() ? param.dtype.value() : -1;
int in_dtype = (*in_attrs)[0];
int out_dtype = (*out_attrs)[0];

if (out_dtype != -1 && in_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
return true;
} else if (in_dtype != -1) {
if (arg_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
}
return true;
} else if (out_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
if (softmax_has_dtype_override(attrs)) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
type_assign(&(*in_attrs)[0], (*out_attrs)[0]);
return true;
} else {
if (arg_dtype != -1) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, arg_dtype);
}
return false;
return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs);
}
}

Expand Down

0 comments on commit bfced75

Please sign in to comment.