@@ -211,14 +211,16 @@ def pattern_table():
211211 return dnnl_patterns
212212
213213
214- def get_optimal_layout_for_conv (data_layout , kernel_layout , weight_shape ,
215- out_shape , paddings , strides , dilates , groups ):
214+ def get_optimal_layout_for_conv (
215+ data_layout , kernel_layout , weight_shape , out_shape , paddings , strides , dilates , groups
216+ ):
216217 """Get the optimal layout of dnnl, given shape of conv2d.
217218
218219 Parameters
219220 ----------
220- data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups :String
221- Input argument.
221+ data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups
222+ : String
223+ Input argument.
222224
223225 Returns
224226 -------
@@ -238,13 +240,22 @@ def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
238240
239241
240242def get_optimal_layout_for_conv_transpose (
241- data_layout , kernel_layout , weight_shape , out_shape , paddings , output_paddings , strides , dilates , groups
243+ data_layout ,
244+ kernel_layout ,
245+ weight_shape ,
246+ out_shape ,
247+ paddings ,
248+ output_paddings ,
249+ strides ,
250+ dilates ,
251+ groups ,
242252):
243253 """Get the optimal layout of dnnl, given shape of tranposed conv2d.
244254
245255 Parameters
246256 ----------
247- data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
257+ data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
258+ dilates, groups
248259 : Int, String
249260 Input argument.
250261
@@ -255,7 +266,7 @@ def get_optimal_layout_for_conv_transpose(
255266 """
256267 return _ffi_api .get_optimal_layout_for_conv_transpose (
257268 data_layout ,
258- kernel_layout ,
269+ kernel_layout ,
259270 weight_shape ,
260271 out_shape ,
261272 paddings ,
@@ -270,16 +281,15 @@ def get_shape(tensor):
270281 """Get tensor's shape."""
271282 if isinstance (tensor , relay .expr .Var ):
272283 return tensor .type_annotation .concrete_shape
273- elif isinstance (tensor , relay .expr .Constant ):
284+ if isinstance (tensor , relay .expr .Constant ):
274285 return tensor .data .shape
275- elif isinstance (tensor , tvm .ir .tensor_type .TensorType ):
286+ if isinstance (tensor , tvm .ir .tensor_type .TensorType ):
276287 return tensor .concrete_shape
277- elif isinstance (tensor , tvm .ir .container .Array ):
288+ if isinstance (tensor , tvm .ir .container .Array ):
278289 return tensor [- 1 ].shape
279- elif isinstance (tensor , relay .expr .Call ):
290+ if isinstance (tensor , relay .expr .Call ):
280291 return tensor .checked_type .shape
281- else :
282- raise TypeError ("Unsupport data type: %s" % type (tensor ))
292+ raise TypeError ("Unsupport data type: %s" % type (tensor ))
283293
284294
285295def tag2layout (input_data , is_weight = False , conv_type = "Conv1D" ):
@@ -318,16 +328,14 @@ def legalize_group_conv(attrs, inputs, types):
318328 """Legalize group conv / conv_transpose calculation.
319329 Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
320330 groups = attrs .groups
321- if groups == 1 :
322- return
323- data , weight = inputs
324- OC , IC , H , W = get_shape (weight )
325- new_attrs = dict (attrs )
326- weight = relay .reshape (weight , (groups , OC // groups , IC , H , W ))
327- if "Transpose" not in type (attrs ).__name__ :
328- new_attrs ["kernel_layout" ] = "GOIHW"
329- return relay .nn .conv2d (data , weight , ** new_attrs )
330- else :
331+ if groups > 1 :
332+ data , weight = inputs
333+ OC , IC , H , W = get_shape (weight )
334+ new_attrs = dict (attrs )
335+ weight = relay .reshape (weight , (groups , OC // groups , IC , H , W ))
336+ if "Transpose" not in type (attrs ).__name__ :
337+ new_attrs ["kernel_layout" ] = "GOIHW"
338+ return relay .nn .conv2d (data , weight , ** new_attrs )
331339 new_attrs ["kernel_layout" ] = "GIOHW"
332340 return relay .nn .conv2d_transpose (data , weight , ** new_attrs )
333341
@@ -346,21 +354,25 @@ def alter_conv(attrs, inputs, tinfos, out_type):
346354 conv_type = type (attrs ).__name__ .split ("Attrs" )[0 ]
347355
348356 res = get_optimal_layout_for_conv (
349- attrs ["data_layout" ], attrs ["kernel_layout" ], weight_shape , out_shape , paddings ,
350- strides , dilates , groups ,
357+ attrs ["data_layout" ],
358+ attrs ["kernel_layout" ],
359+ weight_shape ,
360+ out_shape ,
361+ paddings ,
362+ strides ,
363+ dilates ,
364+ groups ,
351365 )
352366 src_df , weight_df , dst_df = res .split ("," )
353367 new_attrs ["data_layout" ] = tag2layout (src_df , is_weight = False , conv_type = conv_type )
354- new_attrs ["kernel_layout" ] = tag2layout (
355- weight_df , is_weight = True , conv_type = conv_type
356- )
368+ new_attrs ["kernel_layout" ] = tag2layout (weight_df , is_weight = True , conv_type = conv_type )
357369 new_attrs ["out_layout" ] = tag2layout (dst_df , is_weight = False , conv_type = conv_type )
358370
359371 if conv_type == "Conv1D" :
360372 return relay .nn .conv1d (data , weight , ** new_attrs )
361- elif conv_type == "Conv2D" :
373+ if conv_type == "Conv2D" :
362374 return relay .nn .conv2d (data , weight , ** new_attrs )
363- elif conv_type == "Conv3D" :
375+ if conv_type == "Conv3D" :
364376 return relay .nn .conv3d (data , weight , ** new_attrs )
365377
366378
@@ -380,7 +392,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
380392
381393 res = get_optimal_layout_for_conv_transpose (
382394 attrs ["data_layout" ],
383- attrs ["kernel_layout" ],
395+ attrs ["kernel_layout" ],
384396 weight_shape ,
385397 out_shape ,
386398 paddings ,
@@ -391,16 +403,14 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
391403 )
392404 src_df , weight_df , dst_df = res .split ("," )
393405 new_attrs ["data_layout" ] = tag2layout (src_df , is_weight = False , conv_type = conv_type )
394- new_attrs ["kernel_layout" ] = tag2layout (
395- weight_df , is_weight = True , conv_type = conv_type
396- )
406+ new_attrs ["kernel_layout" ] = tag2layout (weight_df , is_weight = True , conv_type = conv_type )
397407 new_attrs ["out_layout" ] = tag2layout (dst_df , is_weight = False , conv_type = conv_type )
398408
399409 if conv_type == "Conv1DTranspose" :
400410 return relay .nn .conv1d_transpose (data , weight , ** new_attrs )
401- elif conv_type == "Conv2DTranspose" :
411+ if conv_type == "Conv2DTranspose" :
402412 return relay .nn .conv2d_transpose (data , weight , ** new_attrs )
403- elif conv_type == "Conv3DTranspose" :
413+ if conv_type == "Conv3DTranspose" :
404414 return relay .nn .conv3d_transpose (data , weight , ** new_attrs )
405415
406416
@@ -418,10 +428,10 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
418428 mod : Module
419429 Annotated and partitioned module.
420430 """
431+ from tvm .relay .testing .temp_op_attr import TempOpAttr
421432
422433 if params :
423434 mod ["main" ] = bind_params_by_name (mod ["main" ], params )
424- from tvm .relay .testing .temp_op_attr import TempOpAttr
425435
426436 with TempOpAttr ("nn.conv2d" , "FTVMLegalize" , legalize_group_conv ):
427437 with TempOpAttr ("nn.conv2d_transpose" , "FTVMLegalize" , legalize_group_conv ):
@@ -443,8 +453,6 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
443453 with tvm .transform .PassContext (opt_level = 3 ):
444454 mod = seq (mod )
445455 if alter_layout :
446- from tvm .relay .testing .temp_op_attr import TempOpAttr
447-
448456 with TempOpAttr ("nn.conv1d" , "FTVMAlterOpLayout" , alter_conv ):
449457 with TempOpAttr ("nn.conv2d" , "FTVMAlterOpLayout" , alter_conv ):
450458 with TempOpAttr ("nn.conv3d" , "FTVMAlterOpLayout" , alter_conv ):
0 commit comments