Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add backward Type inference to main NN operators #18378

Merged
merged 2 commits into from
Jun 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions src/operator/contrib/batch_norm_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
using namespace mshadow;
CHECK_GE(in_type->size(), 1U);
const int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
const size_t n_out = 4;
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
int dtype = (*in_type)[0];

if (type_is_none(dtype)) {
Kh4L marked this conversation as resolved.
Show resolved Hide resolved
// Input type is undefined, we try backward inference
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
// Neither the input nor the output are defined,
// types cannot be infered for this op
return false;
} else {
// Input type is undefined but output type is: backward inference
dtype = (*out_type)[0];
Kh4L marked this conversation as resolved.
Show resolved Hide resolved
(*in_type)[0] = dtype;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
}
} else {
// Input type is defined but output type is not: forward inference
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
Kh4L marked this conversation as resolved.
Show resolved Hide resolved
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
out_type->clear();
out_type->push_back(dtype);
for (size_t i = 1; i < n_out; ++i) {
out_type->push_back(dtype_param);
}
}
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
CHECK_LE(in_type->size(), args.size());
for (size_t i = 1; i < in_type->size(); ++i) {
Expand All @@ -101,12 +123,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
}
}
const size_t n_out = 4;
out_type->clear();
out_type->push_back(dtype);
for (size_t i = 1; i < n_out; ++i) {
out_type->push_back(dtype_param);
}
return true;
}

Expand Down
33 changes: 24 additions & 9 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,14 +352,35 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
using namespace mshadow;
CHECK_GE(in_type->size(), 1U);
const int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
const size_t n_out = 3;
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
int dtype = (*in_type)[0];
if (type_is_none(dtype)) {
// Input type is undefined, we try backward inference
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
// Neither the input nor the output are defined,
// types cannot be infered for this op
return false;
} else {
// Input type is undefined but output type is: backward inference
dtype = (*out_type)[0];
(*in_type)[0] = dtype;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
}
} else {
// Input type is defined but output type is not: forward inference
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
out_type->clear();
out_type->push_back(dtype);
for (size_t i = 1; i < n_out; ++i) {
out_type->push_back(dtype_param);
}
}
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
CHECK_LE(in_type->size(), args.size());
for (size_t i = 1; i < in_type->size(); ++i) {
Expand All @@ -369,12 +390,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
}
}
const size_t n_out = 3;
out_type->clear();
out_type->push_back(dtype);
for (size_t i = 1; i < n_out; ++i) {
out_type->push_back(dtype_param);
}
return true;
}

Expand Down
13 changes: 10 additions & 3 deletions src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,23 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& param_ = nnvm::get<ConvolutionParam>(attrs.parsed);
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
if (type_is_none(dtype)) {
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
return false;
} else {
dtype = (*out_type)[0];
}
} else {
out_type->clear();
out_type->push_back(dtype);
}
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
}
}
out_type->clear();
out_type->push_back(dtype);
return true;
}

Expand Down
18 changes: 15 additions & 3 deletions src/operator/nn/deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,16 +332,28 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed);
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
if (type_is_none(dtype)) {
// Input type is undefined, we try backward inference
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
// Neither the input nor the output are defined,
// types cannot be infered for this op
return false;
} else {
// Input type is undefined but output type is: backward inference
dtype = (*out_type)[0];
}
} else {
// Input type is defined but output type is not: forward inference
out_type->clear();
out_type->push_back(dtype);
}
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
}
}
out_type->clear();
out_type->push_back(dtype);
return true;
}

Expand Down
18 changes: 15 additions & 3 deletions src/operator/softmax_output.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,28 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_type) {
CHECK_EQ(in_type->size(), 2U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
if (type_is_none(dtype)) {
// Input type is undefined, we try backward inference
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
// Neither the input nor the output are defined,
// types cannot be infered for this op
return false;
} else {
// Input type is undefined but output type is: backward inference
dtype = (*out_type)[0];
}
} else {
// Input type is defined but output type is not: forward inference
out_type->clear();
out_type->push_back(dtype);
}
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
out_type->clear();
out_type->push_back(dtype);
return true;
}

Expand Down