@@ -83,14 +83,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
83
83
std::vector<int > *in_type, std::vector<int > *out_type) {
84
84
using namespace mshadow ;
85
85
CHECK_GE (in_type->size (), 1U );
86
- const int dtype = (*in_type)[0 ];
87
- CHECK_NE (dtype, -1 ) << " First input must have specified type" ;
86
+ const size_t n_out = 4 ;
88
87
// For float16 input type beta, gamma, mean, and average are stored in float32.
89
88
// For other input types, these parameters have the same type as input
90
89
// NOTE: This requirement is from cuDNN (v. 4 and 5)
91
90
int dtype_param;
92
- MSHADOW_REAL_TYPE_SWITCH_EX (dtype, DTypeX, AccRealX, {
91
+ int dtype = (*in_type)[0 ];
92
+
93
+ if (type_is_none (dtype)) {
94
+ // Input type is undefined, we try backward inference
95
+ if (out_type->size () == 0 || type_is_none ((*out_type)[0 ])) {
96
+ // Neither the input nor the output are defined,
97
+ // types cannot be infered for this op
98
+ return false ;
99
+ } else {
100
+ // Input type is undefined but output type is: backward inference
101
+ dtype = (*out_type)[0 ];
102
+ (*in_type)[0 ] = dtype;
103
+ MSHADOW_REAL_TYPE_SWITCH_EX (dtype, DTypeX, AccRealX, {
104
+ dtype_param = mshadow::DataType<AccRealX>::kFlag ; });
105
+ }
106
+ } else {
107
+ // Input type is defined but output type is not: forward inference
108
+ MSHADOW_REAL_TYPE_SWITCH_EX (dtype, DTypeX, AccRealX, {
93
109
dtype_param = mshadow::DataType<AccRealX>::kFlag ; });
110
+ out_type->clear ();
111
+ out_type->push_back (dtype);
112
+ for (size_t i = 1 ; i < n_out; ++i) {
113
+ out_type->push_back (dtype_param);
114
+ }
115
+ }
94
116
std::vector<std::string> args{" data" , " gamma" , " beta" , " mean" , " var" };
95
117
CHECK_LE (in_type->size (), args.size ());
96
118
for (size_t i = 1 ; i < in_type->size (); ++i) {
@@ -100,12 +122,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
100
122
UNIFORM_TYPE_CHECK ((*in_type)[i], dtype_param, args[i]);
101
123
}
102
124
}
103
- const size_t n_out = 4 ;
104
- out_type->clear ();
105
- out_type->push_back (dtype);
106
- for (size_t i = 1 ; i < n_out; ++i) {
107
- out_type->push_back (dtype_param);
108
- }
109
125
return true ;
110
126
}
111
127
0 commit comments