diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index 14452cc96729..24ec7d14a191 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -84,14 +84,31 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
                                   std::vector<int> *in_type, std::vector<int> *out_type) {
   using namespace mshadow;
   CHECK_GE(in_type->size(), 1U);
-  const int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const size_t n_out = 4;
   // For float16 input type beta, gamma, mean, and average are stored in float32.
   // For other input types, these parameters have the same type as input
   // NOTE: This requirement is from cuDNN (v. 4 and 5)
   int dtype_param;
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
-      dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+  int dtype = (*in_type)[0];
+  if (type_is_none(dtype)) {
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      return false;
+    } else {
+      dtype = (*out_type)[0];
+      (*in_type)[0] = dtype;
+      MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+        dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    }
+  } else {
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (size_t i = 1; i < n_out; ++i) {
+      out_type->push_back(dtype_param);
+    }
+  }
   std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
   CHECK_LE(in_type->size(), args.size());
   for (size_t i = 1; i < in_type->size(); ++i) {
@@ -101,12 +118,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
     }
   }
-  const size_t n_out = 4;
-  out_type->clear();
-  out_type->push_back(dtype);
-  for (size_t i = 1; i < n_out; ++i) {
-    out_type->push_back(dtype_param);
-  }
   return true;
 }
 
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 815288cfe554..a267479b4cde 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -352,14 +352,31 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
                           std::vector<int> *in_type, std::vector<int> *out_type) {
   using namespace mshadow;
   CHECK_GE(in_type->size(), 1U);
-  const int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const size_t n_out = 3;
   // For float16 input type beta, gamma, mean, and average are stored in float32.
   // For other input types, these parameters have the same type as input
   // NOTE: This requirement is from cuDNN (v. 4 and 5)
   int dtype_param;
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
-      dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+  int dtype = (*in_type)[0];
+  if (type_is_none(dtype)) {
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      return false;
+    } else {
+      dtype = (*out_type)[0];
+      (*in_type)[0] = dtype;
+      MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+        dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    }
+  } else {
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (size_t i = 1; i < n_out; ++i) {
+      out_type->push_back(dtype_param);
+    }
+  }
   std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
   CHECK_LE(in_type->size(), args.size());
   for (size_t i = 1; i < in_type->size(); ++i) {
@@ -369,12 +386,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
     }
   }
-  const size_t n_out = 3;
-  out_type->clear();
-  out_type->push_back(dtype);
-  for (size_t i = 1; i < n_out; ++i) {
-    out_type->push_back(dtype_param);
-  }
   return true;
 }
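The two batch-norm hunks keep the cuDNN-motivated rule stated in the comments: for float16 data, gamma/beta/mean/var and the mean/var side outputs are held in float32, while for every other dtype they share the data type. Below is a minimal sketch of that rule in isolation; AccumulationType and FillBatchNormOutputTypes are hypothetical names for illustration, not MXNet APIs, and the constants follow mshadow's TypeFlag numbering (kFloat32 == 0, kFloat16 == 2).

#include <cstddef>
#include <vector>

// Constants follow mshadow's TypeFlag numbering.
const int kFloat32 = 0;
const int kFloat16 = 2;

// float16 data accumulates in float32 (the cuDNN requirement cited above);
// every other dtype keeps its own type for gamma/beta/mean/var.
inline int AccumulationType(int dtype) {
  return dtype == kFloat16 ? kFloat32 : dtype;
}

// Fill the n_out output slots: slot 0 follows the data dtype, the remaining
// side outputs (mean/var) use the accumulation type, mirroring the
// out_type->push_back(dtype_param) loops in the hunks above.
void FillBatchNormOutputTypes(int dtype, std::size_t n_out,
                              std::vector<int>* out_type) {
  out_type->clear();
  out_type->push_back(dtype);
  for (std::size_t i = 1; i < n_out; ++i) {
    out_type->push_back(AccumulationType(dtype));
  }
}

Note that BatchNormWithReLU has four outputs (n_out = 4) while BatchNorm has three (n_out = 3); the typing rule is otherwise identical, which is why both hunks carry the same structure.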
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 8ff5ea75d5f7..3ebb67ad0aa0 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -285,7 +285,16 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
   const ConvolutionParam& param_ = nnvm::get<ConvolutionParam>(attrs.parsed);
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      return false;
+    } else {
+      dtype = (*out_type)[0];
+    }
+  } else {
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
@@ -293,8 +302,6 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }
 
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index cd22aced0d03..0b0fecc1849f 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -332,7 +332,16 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
   const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed);
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      return false;
+    } else {
+      dtype = (*out_type)[0];
+    }
+  } else {
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
@@ -340,8 +349,6 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }
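The convolution and deconvolution changes show the shared pattern in its purest form: infer forward from the data type when it is known, infer backward from the output type when only that is known, and defer by returning false when neither side is available yet. A distilled sketch under those assumptions follows; InferDataTypeBidirectional and TypeIsNone are stand-in names for illustration (type_is_none is the actual MXNet helper the patch calls).

#include <vector>

// -1 marks an unknown dtype; TypeIsNone mirrors MXNet's type_is_none.
inline bool TypeIsNone(int dtype) { return dtype == -1; }

// Distilled pattern: forward inference when the input dtype is known,
// backward inference from the output when it is not, deferral otherwise.
bool InferDataTypeBidirectional(std::vector<int>* in_type,
                                std::vector<int>* out_type) {
  int dtype = (*in_type)[0];
  if (TypeIsNone(dtype)) {
    if (out_type->empty() || TypeIsNone((*out_type)[0])) {
      return false;  // nothing known on either side; retry later
    }
    dtype = (*out_type)[0];  // backward: adopt the output's dtype
    (*in_type)[0] = dtype;
  } else {
    out_type->clear();       // forward: the output follows the input
    out_type->push_back(dtype);
  }
  return true;
}

In the real operators, the loop that follows each hunk then unifies every remaining input (weight, bias, etc.) against dtype via UNIFORM_TYPE_CHECK, so only the first input participates in the bidirectional step.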
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 13bb647f9d43..da02fcfabbf5 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -66,7 +66,16 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
                               std::vector<int> *out_type) {
   CHECK_EQ(in_type->size(), 2U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      return false;
+    } else {
+      dtype = (*out_type)[0];
+    }
+  } else {
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
@@ -74,8 +83,6 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }
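Replacing the hard CHECK_NE(dtype, -1) with a false return changes each operator's contract: instead of aborting when the first input's type is unknown, the node tells the graph-level pass that it cannot make progress yet, and the pass can revisit it after neighbouring nodes have filled in more types. The toy fixed-point driver below illustrates that protocol; it is a hypothetical simplification for this write-up, not the actual nnvm InferType pass.

#include <cstddef>
#include <functional>
#include <vector>

using InferFn = std::function<bool(std::vector<int>*, std::vector<int>*)>;

// Sweep over the nodes until a full pass makes no progress; nodes that
// returned false are retried on the next sweep with whatever types their
// neighbours have filled in since.
bool RunTypeInference(const std::vector<InferFn>& nodes,
                      std::vector<std::vector<int>>* in_types,
                      std::vector<std::vector<int>>* out_types) {
  std::vector<bool> done(nodes.size(), false);
  bool progress = true;
  while (progress) {
    progress = false;
    for (std::size_t i = 0; i < nodes.size(); ++i) {
      if (!done[i] && nodes[i](&(*in_types)[i], &(*out_types)[i])) {
        done[i] = true;
        progress = true;
      }
    }
  }
  for (bool d : done) {
    if (!d) return false;  // some node never saw enough type information
  }
  return true;
}

Under this protocol a graph whose dtype is pinned only at the output (for example, by a loss layer) can still be fully typed, since the patched operators now propagate that information backward on a later sweep.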