@@ -172,9 +172,32 @@ dnnl::memory::dims str2dims(const std::string& str_shape,
172172 return out_dims;
173173}
174174
175- std::string get_optimal_layout_for_conv (std::string weight_shape,
176- std::string out_shape, std::string paddings,
177- std::string strides, std::string dilates, std::string G) {
175+ void check_shapes (const std::vector<std::string> shapes) {
176+ std::regex valid_pat (" (\\ d*)(,(\\ d*))*" );
177+ bool checked = std::regex_match (shapes[0 ], valid_pat);
178+ for (size_t i = 1 ; i < shapes.size ()-1 ; i++) {
179+ checked &= std::regex_match (shapes[i], valid_pat);
180+ }
181+ checked &= std::regex_match (shapes[shapes.size ()-1 ], std::regex (" \\ d*" ));
182+ if (!checked) {
183+ LOG (FATAL) << " Invalid input args for query dnnl optimal layout." ;
184+ }
185+ }
186+
187+ void check_layout (bool var, bool ref) {
188+ if (var != ref) {
189+ LOG (FATAL) << " Invalid input layout for query dnnl optimal layout." ;
190+ }
191+ }
192+
193+ std::string get_optimal_layout_for_conv (std::string data_layout, std::string kernel_layout,
194+ std::string weight_shape, std::string out_shape,
195+ std::string paddings, std::string strides,
196+ std::string dilates, std::string G) {
197+ check_layout (std::regex_match (data_layout, std::regex (" NC(D?)(H?)W" )), true );
198+ check_layout (std::regex_match (kernel_layout, std::regex (" (G?)OI(D?)(H?)W" )), true );
199+ check_shapes ({weight_shape, out_shape, paddings, strides, dilates, G});
200+
178201 dnnl::engine eng (dnnl::engine::kind::cpu, 0 );
179202 dnnl::stream s (eng);
180203 using tag = dnnl::memory::format_tag;
@@ -249,10 +272,15 @@ std::string get_optimal_layout_for_conv(std::string weight_shape,
249272 return res;
250273}
251274
252- std::string get_optimal_layout_for_conv_transpose (std::string weight_shape,
253- std::string out_shape, std::string paddings,
254- std::string output_paddings, std::string strides,
255- std::string dilates, std::string G) {
275+ std::string get_optimal_layout_for_conv_transpose (std::string data_layout, std::string kernel_layout,
276+ std::string weight_shape, std::string out_shape,
277+ std::string paddings, std::string output_paddings,
278+ std::string strides, std::string dilates,
279+ std::string G) {
280+ check_layout (std::regex_match (data_layout, std::regex (" NC(D?)(H?)W" )), true );
281+ check_layout (std::regex_match (kernel_layout, std::regex (" (G?)((IO)|(OI))(D?)(H?)W" )), true );
282+ check_shapes ({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});
283+
256284 dnnl::engine eng (dnnl::engine::kind::cpu, 0 );
257285 dnnl::stream s (eng);
258286 using tag = dnnl::memory::format_tag;
@@ -335,13 +363,14 @@ std::string get_optimal_layout_for_conv_transpose(std::string weight_shape,
335363
336364TVM_REGISTER_GLOBAL (" relay.ir.get_optimal_layout_for_conv" )
337365 .set_body([](TVMArgs args, TVMRetValue* rv) {
338- *rv = get_optimal_layout_for_conv (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ], args[5 ]);
366+ *rv = get_optimal_layout_for_conv (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ], args[5 ],
367+ args[6 ], args[7 ]);
339368 });
340369
341370TVM_REGISTER_GLOBAL (" relay.ir.get_optimal_layout_for_conv_transpose" )
342371 .set_body([](TVMArgs args, TVMRetValue* rv) {
343372 *rv = get_optimal_layout_for_conv_transpose (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ],
344- args[5 ], args[6 ]);
373+ args[5 ], args[6 ], args[ 7 ], args[ 8 ] );
345374 });
346375
347376} // namespace contrib
0 commit comments