diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 97949ffbc81e..50cfc2f713f4 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -94,16 +94,14 @@ inline static bool SoftmaxGradStorageType(const nnvm::NodeAttrs& attrs,
                                           const int dev_mask,
                                           DispatchMode* dispatch_mode,
                                           std::vector<int>* in_attrs,
                                           std::vector<int>* out_attrs) {
-  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (param.use_length.value() || softmax_has_dtype_override(attrs)) {
-    auto& out_stype = out_attrs->at(0);
-    return storage_type_assign(&out_stype, kDefaultStorage,
-                               dispatch_mode, DispatchMode::kFCompute);
+  bool support = true;
+  if (softmax_use_length(attrs) || softmax_has_dtype_override(attrs)) {
+    support = false;
   }
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
-                           out_attrs);
+
+  CHECK_EQ(in_attrs->size(), SoftmaxGradOpNumInputs(attrs));
+  CHECK_EQ(out_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
+  return MKLDNNStorageType(attrs, dev_mask, support, dispatch_mode, in_attrs, out_attrs);
 }
 #endif