Fix storage type infer of softmax backward (apache#17576)
* fix storage type of softmax backward

* remove unused variable
TaoLv authored and Ubuntu committed Feb 19, 2020
1 parent 6a6717e commit 06560c9
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/operator/nn/softmax.cc
@@ -94,16 +94,14 @@ inline static bool SoftmaxGradStorageType(const nnvm::NodeAttrs& attrs,
                                            DispatchMode* dispatch_mode,
                                            std::vector<int> *in_attrs,
                                            std::vector<int> *out_attrs) {
-  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (param.use_length.value() || softmax_has_dtype_override(attrs)) {
-    auto& out_stype = out_attrs->at(0);
-    return storage_type_assign(&out_stype, kDefaultStorage,
-                               dispatch_mode, DispatchMode::kFCompute);
+  bool support = true;
+  if (softmax_use_length(attrs) || softmax_has_dtype_override(attrs)) {
+    support = false;
   }
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
-                           out_attrs);
+
+  CHECK_EQ(in_attrs->size(), SoftmaxGradOpNumInputs(attrs));
+  CHECK_EQ(out_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
+  return MKLDNNStorageType(attrs, dev_mask, support, dispatch_mode, in_attrs, out_attrs);
 }
 #endif
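For readability, below is a sketch of the full SoftmaxGradStorageType function as it reads after this commit, assembled from the hunk above. The `const int dev_mask` parameter and the `#if MXNET_USE_MKLDNN == 1` guard are not visible in the hunk and are assumptions (only the closing `#endif` and the use of `dev_mask` appear); the comments are added here and are not part of the source.

#if MXNET_USE_MKLDNN == 1  // assumed guard; only the closing #endif is shown in the hunk
inline static bool SoftmaxGradStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,  // assumed parameter; dev_mask is used below
                                          DispatchMode* dispatch_mode,
                                          std::vector<int> *in_attrs,
                                          std::vector<int> *out_attrs) {
  // MKL-DNN softmax backward is not used when a length input or a dtype
  // override is present; signal fallback instead of returning early.
  bool support = true;
  if (softmax_use_length(attrs) || softmax_has_dtype_override(attrs)) {
    support = false;
  }

  // The gradient op takes a variable number of inputs, and an extra output
  // when use_length is set, so the counts are no longer hard-coded to 2 and 1.
  CHECK_EQ(in_attrs->size(), SoftmaxGradOpNumInputs(attrs));
  CHECK_EQ(out_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
  return MKLDNNStorageType(attrs, dev_mask, support, dispatch_mode, in_attrs, out_attrs);
}
#endif

Passing `support` through MKLDNNStorageType, rather than assigning only out_attrs->at(0) and returning as the old code did, appears to let the helper assign storage types consistently for all inputs and outputs in the fallback case as well, which matters now that the gradient op can carry the extra length input and output.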

