@@ -146,24 +146,33 @@ std::string md2fmt_tag_str(const dnnl::memory::desc* md) {
146146 return s;
147147}
148148
149- dnnl::memory::dims str2dims (std::string str_shape, int input_size) {
150- std::string str_reg = " (\\ d*)" ;
151- for (int i = 0 ; i < input_size - 1 ; i++) {
152- str_reg.append (" ,(\\ d*)" );
149+ dnnl::memory::dims str2dims (const std::string& str_shape,
150+ bool dilates = false ,
151+ std::string interval = " ," ) {
152+ // Split strings
153+ std::vector<std::string> str_dims;
154+ size_t pos = 0 , start = 0 ;
155+ while ((pos = str_shape.find (interval, start)) != std::string::npos) {
156+ std::string str_dim = str_shape.substr (start, pos - start);
157+ if (pos > start) str_dims.push_back (str_dim);
158+ start = pos + interval.size ();
153159 }
154- std::regex rex (str_reg);
155- std::smatch m;
160+ if (str_shape.size () > start) {
161+ str_dims.push_back (str_shape.substr (start));
162+ }
163+ // transfer string to dims
156164 dnnl::memory::dims out_dims;
157- if (std::regex_search (str_shape, m, rex) ) {
158- std::transform (m .begin () + 1 , m .end (), std::back_inserter (out_dims),
159- [](const std::string& str) { return std::stoi (str); });
165+ if (dilates ) {
166+ std::transform (str_dims .begin (), str_dims .end (), std::back_inserter (out_dims),
167+ [](const std::string& str) { return std::stoi (str) - 1 ; });
160168 } else {
161- LOG (FATAL) << " Unsupported shape for querying optimal dnnl layout: " << str_shape;
169+ std::transform (str_dims.begin (), str_dims.end (), std::back_inserter (out_dims),
170+ [](const std::string& str) { return std::stoi (str); });
162171 }
163172 return out_dims;
164173}
165174
166- std::string get_optimal_layout_for_conv (int input_size, std::string weight_shape,
175+ std::string get_optimal_layout_for_conv (std::string weight_shape,
167176 std::string out_shape, std::string paddings,
168177 std::string strides, std::string dilates, std::string G) {
169178 dnnl::engine eng (dnnl::engine::kind::cpu, 0 );
@@ -172,35 +181,36 @@ std::string get_optimal_layout_for_conv(int input_size, std::string weight_shape
172181 using dt = dnnl::memory::data_type;
173182
174183 dnnl::memory::dim groups = std::stoi (G);
175- dnnl::memory::dims weight_dims_ = str2dims (weight_shape, input_size );
184+ dnnl::memory::dims weight_dims_ = str2dims (weight_shape);
176185 dnnl::memory::dims weight_dims = weight_dims_;
186+
177187 if (groups > 1 ) {
178188 if (weight_dims_.size () == 5 ) {
179- weight_dims = {weight_dims_[ 0 ] * weight_dims_[1 ], weight_dims_[2 ], weight_dims_[3 ],
189+ weight_dims = {groups * weight_dims_[1 ], groups * weight_dims_[2 ], weight_dims_[3 ],
180190 weight_dims_[4 ]};
181191 } else {
182192 weight_dims[1 ] = weight_dims[1 ] * groups;
183193 }
184194 }
185- dnnl::memory::dims out_dims = str2dims (out_shape, input_size);
186- dnnl::memory::dims padding_dims = str2dims (paddings, 2 * (input_size - 2 ));
195+
196+ dnnl::memory::dims out_dims = str2dims (out_shape);
197+ dnnl::memory::dims padding_dims = str2dims (paddings);
187198 dnnl::memory::dims padding_dims_l (padding_dims.begin (),
188199 padding_dims.begin () + padding_dims.size () / 2 );
189200 dnnl::memory::dims padding_dims_r (padding_dims.end () - padding_dims.size () / 2 ,
190201 padding_dims.end ());
191- dnnl::memory::dims strides_dims = str2dims (strides, input_size - 2 );
192- dnnl::memory::dims dilates_dims = str2dims (dilates, input_size - 2 );
202+ dnnl::memory::dims strides_dims = str2dims (strides);
203+ dnnl::memory::dims dilates_dims = str2dims (dilates, true );
193204
194205 dnnl::memory::dims input_dims = out_dims;
195206 input_dims[1 ] = weight_dims[1 ];
196- for (int i = 2 ; i < input_size ; i++) {
207+ for (int i = 2 ; i < out_dims. size () ; i++) {
197208 dnnl::memory::dim K = weight_dims[i];
198209 dnnl::memory::dim S = strides_dims[i - 2 ];
199- dnnl::memory::dim D = dilates_dims[i - 2 ] - 1 ;
210+ dnnl::memory::dim D = dilates_dims[i - 2 ];
200211 dnnl::memory::dim PL = padding_dims_l[i - 2 ];
201212 dnnl::memory::dim PR = padding_dims_r[i - 2 ];
202213 dnnl::memory::dim DK = 1 + (K - 1 ) * (D + 1 );
203- dilates_dims[i - 2 ] = D;
204214 input_dims[i] = out_dims[i] * S - PL - PR + DK - 1 ;
205215 }
206216
@@ -210,6 +220,7 @@ std::string get_optimal_layout_for_conv(int input_size, std::string weight_shape
210220 conv_weights_dims = {groups, out_dims[1 ] / groups, input_dims[1 ] / groups};
211221 conv_weights_dims.insert (conv_weights_dims.end (), weight_dims.begin () + 2 , weight_dims.end ());
212222 }
223+
213224 dnnl::memory::dims conv_dst_dims = out_dims;
214225 dnnl::memory::dims conv_strides = strides_dims;
215226 dnnl::memory::dims conv_dilates = dilates_dims;
@@ -238,7 +249,7 @@ std::string get_optimal_layout_for_conv(int input_size, std::string weight_shape
238249 return res;
239250}
240251
241- std::string get_optimal_layout_for_conv_transpose (int input_size, std::string weight_shape,
252+ std::string get_optimal_layout_for_conv_transpose (std::string weight_shape,
242253 std::string out_shape, std::string paddings,
243254 std::string output_paddings, std::string strides,
244255 std::string dilates, std::string G) {
@@ -248,25 +259,25 @@ std::string get_optimal_layout_for_conv_transpose(int input_size, std::string we
248259 using dt = dnnl::memory::data_type;
249260
250261 dnnl::memory::dim groups = std::stoi (G);
251- dnnl::memory::dims weight_dims_ = str2dims (weight_shape, input_size );
262+ dnnl::memory::dims weight_dims_ = str2dims (weight_shape);
252263 dnnl::memory::dims weight_dims = weight_dims_;
253264 if (groups > 1 ) {
254265 if (weight_dims_.size () == 5 ) {
255- weight_dims = {weight_dims_[ 0 ] * weight_dims_[1 ], weight_dims_[2 ], weight_dims_[3 ],
266+ weight_dims = {groups * weight_dims_[1 ], groups * weight_dims_[2 ], weight_dims_[3 ],
256267 weight_dims_[4 ]};
257268 } else {
258269 weight_dims[1 ] = weight_dims[1 ] * groups;
259270 }
260271 }
261- dnnl::memory::dims out_dims = str2dims (out_shape, input_size );
262- dnnl::memory::dims padding_dims = str2dims (paddings, 2 * (input_size - 2 ) );
272+ dnnl::memory::dims out_dims = str2dims (out_shape);
273+ dnnl::memory::dims padding_dims = str2dims (paddings);
263274 dnnl::memory::dims padding_dims_l (padding_dims.begin (),
264275 padding_dims.begin () + padding_dims.size () / 2 );
265276 dnnl::memory::dims padding_dims_r (padding_dims.end () - padding_dims.size () / 2 ,
266277 padding_dims.end ());
267- dnnl::memory::dims output_padding_dims = str2dims (output_paddings, input_size - 2 );
268- dnnl::memory::dims strides_dims = str2dims (strides, input_size - 2 );
269- dnnl::memory::dims dilates_dims = str2dims (dilates, input_size - 2 );
278+ dnnl::memory::dims output_padding_dims = str2dims (output_paddings);
279+ dnnl::memory::dims strides_dims = str2dims (strides);
280+ dnnl::memory::dims dilates_dims = str2dims (dilates, true );
270281
271282 dnnl::memory::dims input_dims = out_dims;
272283 if (out_dims[1 ] == weight_dims[0 ]) {
@@ -275,15 +286,14 @@ std::string get_optimal_layout_for_conv_transpose(int input_size, std::string we
275286 input_dims[1 ] = weight_dims[0 ];
276287 std::swap (weight_dims[0 ], weight_dims[1 ]);
277288 }
278- for (int i = 2 ; i < input_size ; i++) {
289+ for (int i = 2 ; i < out_dims. size () ; i++) {
279290 dnnl::memory::dim K = weight_dims[i];
280291 dnnl::memory::dim S = strides_dims[i - 2 ];
281- dnnl::memory::dim D = dilates_dims[i - 2 ] - 1 ;
292+ dnnl::memory::dim D = dilates_dims[i - 2 ];
282293 dnnl::memory::dim PL = padding_dims_l[i - 2 ];
283294 dnnl::memory::dim PR = padding_dims_r[i - 2 ];
284295 dnnl::memory::dim OP = output_padding_dims[i - 2 ];
285296 dnnl::memory::dim DK = 1 + (K - 1 ) * (D + 1 );
286- dilates_dims[i - 2 ] = D;
287297 input_dims[i] = (out_dims[i] - DK + PL + PR - OP) / S + 1 ;
288298 }
289299
@@ -325,14 +335,13 @@ std::string get_optimal_layout_for_conv_transpose(int input_size, std::string we
325335
326336TVM_REGISTER_GLOBAL (" relay.ir.get_optimal_layout_for_conv" )
327337 .set_body([](TVMArgs args, TVMRetValue* rv) {
328- *rv = get_optimal_layout_for_conv (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ], args[5 ],
329- args[6 ]);
338+ *rv = get_optimal_layout_for_conv (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ], args[5 ]);
330339 });
331340
332341TVM_REGISTER_GLOBAL (" relay.ir.get_optimal_layout_for_conv_transpose" )
333342 .set_body([](TVMArgs args, TVMRetValue* rv) {
334343 *rv = get_optimal_layout_for_conv_transpose (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ],
335- args[5 ], args[6 ], args[ 7 ] );
344+ args[5 ], args[6 ]);
336345 });
337346
338347} // namespace contrib
0 commit comments