 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.testing.temp_op_attr import TempOpAttr

 from ... import _ffi_api
 from ...dataflow_pattern import wildcard, is_op
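For reference, `TempOpAttr` (now imported at module level instead of inside `partition_for_dnnl`) is a context manager that temporarily overrides an operator attribute and restores the previous registration on exit. A minimal sketch of that behaviour, with a hypothetical `noop_legalize` callback used only for illustration:

```python
from tvm.relay.testing.temp_op_attr import TempOpAttr


def noop_legalize(attrs, inputs, types):
    # Hypothetical FTVMLegalize callback used only for illustration;
    # returning None tells the Legalize pass to leave the call unchanged.
    return None


with TempOpAttr("nn.conv2d", "FTVMLegalize", noop_legalize):
    # Inside this block, passes see noop_legalize as nn.conv2d's FTVMLegalize.
    pass
# On exit, the previously registered FTVMLegalize (if any) is restored.
```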
@@ -211,14 +212,16 @@ def pattern_table():
     return dnnl_patterns


-def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
-                                out_shape, paddings, strides, dilates, groups):
+def get_optimal_layout_for_conv(
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
+):
     """Get the optimal layout of dnnl, given shape of conv2d.

     Parameters
     ----------
-    data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups :String
-        Input argument.
+    data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups
+    : String
+        Input argument.

     Returns
     -------
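A hedged usage sketch for the function above (assuming this file is importable as `tvm.relay.op.contrib.dnnl` and TVM was built with DNNL/oneDNN support): per the docstring every argument is passed as a string, and the callers further down split the result into source, weight, and destination layout tags. All concrete values here are hypothetical.

```python
from tvm.relay.op.contrib.dnnl import get_optimal_layout_for_conv

# Hypothetical shapes/strides encoded as strings.
res = get_optimal_layout_for_conv(
    "NCHW",          # data_layout
    "OIHW",          # kernel_layout
    "32,3,3,3",      # weight_shape
    "1,32,222,222",  # out_shape
    "0,0",           # paddings
    "1,1",           # strides
    "1,1",           # dilates
    "1",             # groups
)
# The result is a comma-separated triple of layout tags, split by the callers:
src_tag, weight_tag, dst_tag = res.split(",")
```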
@@ -238,13 +241,22 @@ def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,


 def get_optimal_layout_for_conv_transpose(
-    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
+    data_layout,
+    kernel_layout,
+    weight_shape,
+    out_shape,
+    paddings,
+    output_paddings,
+    strides,
+    dilates,
+    groups,
 ):
     """Get the optimal layout of dnnl, given shape of tranposed conv2d.

     Parameters
     ----------
-    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
+    dilates, groups
     : Int, String
         Input argument.

@@ -255,7 +267,7 @@ def get_optimal_layout_for_conv_transpose(
255267 """
256268 return _ffi_api .get_optimal_layout_for_conv_transpose (
257269 data_layout ,
258- kernel_layout ,
270+ kernel_layout ,
259271 weight_shape ,
260272 out_shape ,
261273 paddings ,
@@ -270,16 +282,15 @@ def get_shape(tensor):
270282 """Get tensor's shape."""
271283 if isinstance (tensor , relay .expr .Var ):
272284 return tensor .type_annotation .concrete_shape
273- elif isinstance (tensor , relay .expr .Constant ):
285+ if isinstance (tensor , relay .expr .Constant ):
274286 return tensor .data .shape
275- elif isinstance (tensor , tvm .ir .tensor_type .TensorType ):
287+ if isinstance (tensor , tvm .ir .tensor_type .TensorType ):
276288 return tensor .concrete_shape
277- elif isinstance (tensor , tvm .ir .container .Array ):
289+ if isinstance (tensor , tvm .ir .container .Array ):
278290 return tensor [- 1 ].shape
279- elif isinstance (tensor , relay .expr .Call ):
291+ if isinstance (tensor , relay .expr .Call ):
280292 return tensor .checked_type .shape
281- else :
282- raise TypeError ("Unsupport data type: %s" % type (tensor ))
293+ raise TypeError ("Unsupport data type: %s" % type (tensor ))
283294
284295
285296def tag2layout (input_data , is_weight = False , conv_type = "Conv1D" ):
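A brief sketch of `get_shape` on the two most common inputs (assuming this module is importable as `tvm.relay.op.contrib.dnnl`):

```python
import numpy as np
from tvm import relay
from tvm.relay.op.contrib.dnnl import get_shape

var = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
const = relay.const(np.zeros((8, 3, 3, 3), dtype="float32"))

print(get_shape(var))    # (1, 3, 224, 224), read from the Var's type annotation
print(get_shape(const))  # (8, 3, 3, 3), read from the Constant's data
```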
@@ -318,18 +329,19 @@ def legalize_group_conv(attrs, inputs, types):
318329 """Legalize group conv / conv_transpose calculation.
319330 Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
320331 groups = attrs .groups
321- if groups == 1 :
322- return
323332 data , weight = inputs
333+ if groups == 1 :
334+ if "Transpose" not in type (attrs ).__name__ :
335+ return relay .nn .conv2d (data , weight , ** attrs )
336+ return relay .nn .conv2d_transpose (data , weight , ** attrs )
324337 OC , IC , H , W = get_shape (weight )
325338 new_attrs = dict (attrs )
326339 weight = relay .reshape (weight , (groups , OC // groups , IC , H , W ))
327340 if "Transpose" not in type (attrs ).__name__ :
328341 new_attrs ["kernel_layout" ] = "GOIHW"
329342 return relay .nn .conv2d (data , weight , ** new_attrs )
330- else :
331- new_attrs ["kernel_layout" ] = "GIOHW"
332- return relay .nn .conv2d_transpose (data , weight , ** new_attrs )
343+ new_attrs ["kernel_layout" ] = "GIOHW"
344+ return relay .nn .conv2d_transpose (data , weight , ** new_attrs )
333345
334346
335347def alter_conv (attrs , inputs , tinfos , out_type ):
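The grouped path above reshapes an OIHW weight into the 5-D grouped layout before rebuilding the call. A small sketch of the same reshape on hypothetical shapes:

```python
from tvm import relay

groups = 4
data = relay.var("data", shape=(1, 32, 28, 28), dtype="float32")
weight = relay.var("weight", shape=(32, 8, 3, 3), dtype="float32")  # OIHW, IC = 32 // groups

# (OC, IC, H, W) -> (G, OC // G, IC, H, W), matching the legalization above.
grouped_weight = relay.reshape(weight, (groups, 32 // groups, 8, 3, 3))
out = relay.nn.conv2d(
    data,
    grouped_weight,
    groups=groups,
    channels=32,
    kernel_size=(3, 3),
    kernel_layout="GOIHW",
)
```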
@@ -346,22 +358,25 @@ def alter_conv(attrs, inputs, tinfos, out_type):
     conv_type = type(attrs).__name__.split("Attrs")[0]

     res = get_optimal_layout_for_conv(
-        attrs["data_layout"], attrs["kernel_layout"], weight_shape, out_shape, paddings,
-        strides, dilates, groups,
+        attrs["data_layout"],
+        attrs["kernel_layout"],
+        weight_shape,
+        out_shape,
+        paddings,
+        strides,
+        dilates,
+        groups,
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
-    new_attrs["kernel_layout"] = tag2layout(
-        weight_df, is_weight=True, conv_type=conv_type
-    )
+    new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
     new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)

     if conv_type == "Conv1D":
         return relay.nn.conv1d(data, weight, **new_attrs)
-    elif conv_type == "Conv2D":
+    if conv_type == "Conv2D":
         return relay.nn.conv2d(data, weight, **new_attrs)
-    elif conv_type == "Conv3D":
-        return relay.nn.conv3d(data, weight, **new_attrs)
+    return relay.nn.conv3d(data, weight, **new_attrs)


 def alter_conv_transpose(attrs, inputs, tinfos, out_type):
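`alter_conv` is intended to run under the `AlterOpLayout` pass, registered via `TempOpAttr` exactly as `partition_for_dnnl` does further down. A hedged sketch of wiring it up by hand (the `run_alter_op_layout` helper is hypothetical, and calling it requires a DNNL-enabled TVM build):

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import alter_conv
from tvm.relay.testing.temp_op_attr import TempOpAttr


def run_alter_op_layout(mod):
    # Temporarily register alter_conv for nn.conv2d, then run AlterOpLayout.
    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
        seq = tvm.transform.Sequential(
            [relay.transform.InferType(), relay.transform.AlterOpLayout()]
        )
        with tvm.transform.PassContext(opt_level=3):
            return seq(mod)
```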
@@ -380,7 +395,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):

     res = get_optimal_layout_for_conv_transpose(
         attrs["data_layout"],
-        attrs["kernel_layout"] ,
+        attrs["kernel_layout"],
         weight_shape,
         out_shape,
         paddings,
@@ -391,17 +406,14 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
-    new_attrs["kernel_layout"] = tag2layout(
-        weight_df, is_weight=True, conv_type=conv_type
-    )
+    new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
     new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)

     if conv_type == "Conv1DTranspose":
         return relay.nn.conv1d_transpose(data, weight, **new_attrs)
-    elif conv_type == "Conv2DTranspose":
+    if conv_type == "Conv2DTranspose":
         return relay.nn.conv2d_transpose(data, weight, **new_attrs)
-    elif conv_type == "Conv3DTranspose":
-        return relay.nn.conv3d_transpose(data, weight, **new_attrs)
+    return relay.nn.conv3d_transpose(data, weight, **new_attrs)


 def partition_for_dnnl(mod, params=None, alter_layout=True):
@@ -418,10 +430,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
     mod : Module
         Annotated and partitioned module.
     """
-
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
-    from tvm.relay.testing.temp_op_attr import TempOpAttr

     with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_group_conv):
         with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", legalize_group_conv):
@@ -443,8 +453,6 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
     with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     if alter_layout:
-        from tvm.relay.testing.temp_op_attr import TempOpAttr
-
         with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", alter_conv):
             with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
                 with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", alter_conv):
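End to end, `partition_for_dnnl` is the entry point called before building. A hedged usage sketch (the toy network is hypothetical, and building requires TVM compiled with the DNNL/oneDNN contrib codegen enabled):

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

# Hypothetical single-conv network used only for illustration.
data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
out = relay.nn.relu(relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3)))
mod = tvm.IRModule.from_expr(out)

# Legalize grouped convs, partition DNNL regions, and (optionally) alter layouts.
mod = partition_for_dnnl(mod, params=None, alter_layout=True)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
```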