Skip to content

Commit 0bdf360

Browse files
Kh4LAntiZpvoh
authored and committed
Add backward Type inference to main NN operators (apache#18378)
* Add backward Type inference to main DNN operators — Signed-off-by: Serge Panev <[email protected]>
* Add comments — Signed-off-by: Serge Panev <[email protected]>
1 parent 821c2f4 commit 0bdf360

File tree

5 files changed

+89
-27
lines changed

5 files changed

+89
-27
lines changed

src/operator/contrib/batch_norm_relu.cc

+25-9
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
8383
std::vector<int> *in_type, std::vector<int> *out_type) {
8484
using namespace mshadow;
8585
CHECK_GE(in_type->size(), 1U);
86-
const int dtype = (*in_type)[0];
87-
CHECK_NE(dtype, -1) << "First input must have specified type";
86+
const size_t n_out = 4;
8887
// For float16 input type beta, gamma, mean, and average are stored in float32.
8988
// For other input types, these parameters have the same type as input
9089
// NOTE: This requirement is from cuDNN (v. 4 and 5)
9190
int dtype_param;
92-
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
91+
int dtype = (*in_type)[0];
92+
93+
if (type_is_none(dtype)) {
94+
// Input type is undefined, we try backward inference
95+
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
96+
// Neither the input nor the output are defined,
97+
// types cannot be infered for this op
98+
return false;
99+
} else {
100+
// Input type is undefined but output type is: backward inference
101+
dtype = (*out_type)[0];
102+
(*in_type)[0] = dtype;
103+
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
104+
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
105+
}
106+
} else {
107+
// Input type is defined but output type is not: forward inference
108+
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
93109
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
110+
out_type->clear();
111+
out_type->push_back(dtype);
112+
for (size_t i = 1; i < n_out; ++i) {
113+
out_type->push_back(dtype_param);
114+
}
115+
}
94116
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
95117
CHECK_LE(in_type->size(), args.size());
96118
for (size_t i = 1; i < in_type->size(); ++i) {
@@ -100,12 +122,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
100122
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
101123
}
102124
}
103-
const size_t n_out = 4;
104-
out_type->clear();
105-
out_type->push_back(dtype);
106-
for (size_t i = 1; i < n_out; ++i) {
107-
out_type->push_back(dtype_param);
108-
}
109125
return true;
110126
}
111127

src/operator/nn/batch_norm.cc

+24-9
Original file line numberDiff line numberDiff line change
@@ -392,14 +392,35 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
392392
std::vector<int> *in_type, std::vector<int> *out_type) {
393393
using namespace mshadow;
394394
CHECK_GE(in_type->size(), 1U);
395-
const int dtype = (*in_type)[0];
396-
CHECK_NE(dtype, -1) << "First input must have specified type";
395+
const size_t n_out = 3;
397396
// For float16 input type beta, gamma, mean, and average are stored in float32.
398397
// For other input types, these parameters have the same type as input
399398
// NOTE: This requirement is from cuDNN (v. 4 and 5)
400399
int dtype_param;
401-
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
400+
int dtype = (*in_type)[0];
401+
if (type_is_none(dtype)) {
402+
// Input type is undefined, we try backward inference
403+
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
404+
// Neither the input nor the output are defined,
405+
// types cannot be infered for this op
406+
return false;
407+
} else {
408+
// Input type is undefined but output type is: backward inference
409+
dtype = (*out_type)[0];
410+
(*in_type)[0] = dtype;
411+
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
412+
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
413+
}
414+
} else {
415+
// Input type is defined but output type is not: forward inference
416+
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
402417
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
418+
out_type->clear();
419+
out_type->push_back(dtype);
420+
for (size_t i = 1; i < n_out; ++i) {
421+
out_type->push_back(dtype_param);
422+
}
423+
}
403424
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
404425
CHECK_LE(in_type->size(), args.size());
405426
for (size_t i = 1; i < in_type->size(); ++i) {
@@ -409,12 +430,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
409430
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
410431
}
411432
}
412-
const size_t n_out = 3;
413-
out_type->clear();
414-
out_type->push_back(dtype);
415-
for (size_t i = 1; i < n_out; ++i) {
416-
out_type->push_back(dtype_param);
417-
}
418433
return true;
419434
}
420435

src/operator/nn/convolution.cc

+10-3
Original file line numberDiff line numberDiff line change
@@ -285,16 +285,23 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
285285
const ConvolutionParam& param_ = nnvm::get<ConvolutionParam>(attrs.parsed);
286286
CHECK_GE(in_type->size(), 1U);
287287
int dtype = (*in_type)[0];
288-
CHECK_NE(dtype, -1) << "First input must have specified type";
288+
if (type_is_none(dtype)) {
289+
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
290+
return false;
291+
} else {
292+
dtype = (*out_type)[0];
293+
}
294+
} else {
295+
out_type->clear();
296+
out_type->push_back(dtype);
297+
}
289298
for (size_t i = 0; i < in_type->size(); ++i) {
290299
if ((*in_type)[i] == -1) {
291300
(*in_type)[i] = dtype;
292301
} else {
293302
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
294303
}
295304
}
296-
out_type->clear();
297-
out_type->push_back(dtype);
298305
return true;
299306
}
300307

src/operator/nn/deconvolution.cc

+15-3
Original file line numberDiff line numberDiff line change
@@ -332,16 +332,28 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
332332
const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed);
333333
CHECK_GE(in_type->size(), 1U);
334334
int dtype = (*in_type)[0];
335-
CHECK_NE(dtype, -1) << "First input must have specified type";
335+
if (type_is_none(dtype)) {
336+
// Input type is undefined, we try backward inference
337+
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
338+
// Neither the input nor the output are defined,
339+
// types cannot be infered for this op
340+
return false;
341+
} else {
342+
// Input type is undefined but output type is: backward inference
343+
dtype = (*out_type)[0];
344+
}
345+
} else {
346+
// Input type is defined but output type is not: forward inference
347+
out_type->clear();
348+
out_type->push_back(dtype);
349+
}
336350
for (size_t i = 0; i < in_type->size(); ++i) {
337351
if ((*in_type)[i] == -1) {
338352
(*in_type)[i] = dtype;
339353
} else {
340354
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
341355
}
342356
}
343-
out_type->clear();
344-
out_type->push_back(dtype);
345357
return true;
346358
}
347359

src/operator/softmax_output.cc

+15-3
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,28 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
6666
std::vector<int> *out_type) {
6767
CHECK_EQ(in_type->size(), 2U);
6868
int dtype = (*in_type)[0];
69-
CHECK_NE(dtype, -1) << "First input must have specified type";
69+
if (type_is_none(dtype)) {
70+
// Input type is undefined, we try backward inference
71+
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
72+
// Neither the input nor the output are defined,
73+
// types cannot be infered for this op
74+
return false;
75+
} else {
76+
// Input type is undefined but output type is: backward inference
77+
dtype = (*out_type)[0];
78+
}
79+
} else {
80+
// Input type is defined but output type is not: forward inference
81+
out_type->clear();
82+
out_type->push_back(dtype);
83+
}
7084
for (size_t i = 0; i < in_type->size(); ++i) {
7185
if ((*in_type)[i] == -1) {
7286
(*in_type)[i] = dtype;
7387
} else {
7488
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
7589
}
7690
}
77-
out_type->clear();
78-
out_type->push_back(dtype);
7991
return true;
8092
}
8193

0 commit comments

Comments
 (0)