3535import logging
3636
3737import tvm .ir
38- from tvm .relay import transform
39- from tvm .relay .build_module import bind_params_by_name
38+ from tvm import relay
4039
40+ from ... import _ffi_api
4141from ...dataflow_pattern import wildcard , is_op
4242from .register import register_pattern_table
4343
@@ -94,12 +94,12 @@ def _func_wrapper(expr):
9494
9595
9696def make_conv_pattern (conv_name , with_bias = True , with_eltwise = None ):
97- """Create patterns related to conv and deconv .
97+ """Create patterns related to conv and conv_transpose .
9898
9999 Parameters
100100 ----------
101101 with_bias : bool
102- Whether attach `bias_add` to `conv / deconv `.
102+ Whether attach `bias_add` to `conv / conv_transpose `.
103103 with_eltwise : str
104104 The attached elementwise post-op name.
105105 Returns
@@ -147,12 +147,12 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
147147 return dense_out
148148
149149
150- def make_dnnl_pattern (op , with_bias , with_eltwise ):
150+ def make_dnnl_pattern (op_name , with_bias , with_eltwise ):
151151 """Create dnnl patterns.
152152
153153 Parameters
154154 ----------
155- op : str
155+ op_name : str
156156 The first call node's op name.
157157 with_bias : bool
158158 Whether attach `bias_add` to `nn.dense`.
@@ -163,18 +163,20 @@ def make_dnnl_pattern(op, with_bias, with_eltwise):
163163 pattern : Tuple(pattern_name, CallPattern)
164164 Created pattern name, along with its CallPattern.
165165 """
166- pat_name = op .replace ("nn" , "dnnl" )
166+ pat_name = op_name .replace ("nn" , "dnnl" )
167+ if "_transpose" in op_name :
168+ pat_name = "dnnl.deconv" + op_name .split ("_" )[0 ][- 2 ::]
167169 pat_name += "_bias" if with_bias else ""
168170 pat_name += ("_" + with_eltwise .split ("." )[- 1 ]) if with_eltwise else ""
169- if "conv" in op :
170- dnnl_pattern = (pat_name , make_conv_pattern (op , with_bias , with_eltwise ))
171- elif op == "nn.dense" :
171+ if "conv" in op_name :
172+ dnnl_pattern = (pat_name , make_conv_pattern (op_name , with_bias , with_eltwise ))
173+ elif op_name == "nn.dense" :
172174 dnnl_pattern = (pat_name , make_dense_pattern (with_bias , with_eltwise ))
173175 else :
174176 logger .warning (
175177 "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and "
176178 "dense op are supported, but got %s." ,
177- op ,
179+ op_name ,
178180 )
179181 dnnl_pattern = ()
180182 return dnnl_pattern
@@ -207,39 +209,205 @@ def pattern_table():
207209 return dnnl_patterns
208210
209211
210- def partition_for_dnnl (mod , params = None ):
211- """Partition the graph greedily offloading supported operators to DNNL.
212+ def get_optimal_layout_for_conv (
213+ data_layout , kernel_layout , weight_shape , out_shape , paddings , strides , dilates , groups
214+ ):
215+ """Get the optimal layout of dnnl, given shape of conv2d.
212216
213217 Parameters
214218 ----------
215- mod : Module
216- The module to run passes on.
217- params : Optional[Dict[str, NDArray]]
218- Constant input parameters.
219+ data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups
220+ : String
221+ Input argument.
222+
219223 Returns
220224 -------
221- mod : Module
222- Annotated and partitioned module .
225+ layouts : string
226+ The result .
223227 """
228+ return _ffi_api .get_optimal_layout_for_conv (
229+ data_layout ,
230+ kernel_layout ,
231+ weight_shape ,
232+ out_shape ,
233+ paddings ,
234+ strides ,
235+ dilates ,
236+ groups ,
237+ )
238+
239+
240+ def get_optimal_layout_for_conv_transpose (
241+ data_layout ,
242+ kernel_layout ,
243+ weight_shape ,
244+ out_shape ,
245+ paddings ,
246+ output_paddings ,
247+ strides ,
248+ dilates ,
249+ groups ,
250+ ):
251+ """Get the optimal layout of dnnl, given shape of tranposed conv2d.
252+
253+ Parameters
254+ ----------
255+ data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
256+ dilates, groups
257+ : Int, String
258+ Input argument.
259+
260+ Returns
261+ -------
262+ layouts : string
263+ The result.
264+ """
265+ return _ffi_api .get_optimal_layout_for_conv_transpose (
266+ data_layout ,
267+ kernel_layout ,
268+ weight_shape ,
269+ out_shape ,
270+ paddings ,
271+ output_paddings ,
272+ strides ,
273+ dilates ,
274+ groups ,
275+ )
276+
277+
278+ def get_shape (tensor ):
279+ """Get tensor's shape."""
280+ if isinstance (tensor , relay .expr .Var ):
281+ return tensor .type_annotation .concrete_shape
282+ if isinstance (tensor , relay .expr .Constant ):
283+ return tensor .data .shape
284+ if isinstance (tensor , tvm .ir .tensor_type .TensorType ):
285+ return tensor .concrete_shape
286+ if isinstance (tensor , tvm .ir .container .Array ):
287+ return tensor [- 1 ].shape
288+ if isinstance (tensor , relay .expr .Call ):
289+ return tensor .checked_type .shape
290+ raise TypeError ("Unsupport data type: %s" % type (tensor ))
291+
224292
225- if params :
226- mod ["main" ] = bind_params_by_name (mod ["main" ], params )
227- seq = tvm .transform .Sequential (
228- [
229- transform .CanonicalizeOps (),
230- transform .InferType (),
231- transform .SimplifyInference (),
232- transform .FoldConstant (),
233- transform .FoldScaleAxis (),
234- # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
235- transform .SimplifyExpr (),
236- transform .FoldConstant (),
237- transform .MergeComposite (pattern_table ()),
238- transform .AnnotateTarget ("dnnl" ),
239- transform .MergeCompilerRegions (),
240- transform .PartitionGraph (),
241- ]
293+ def tag2layout (input_data , is_weight = False , conv_type = "Conv1D" ):
294+ """Transfer layout, denoted with `a, b, c, d, e`,
295+ into valid layout (NCHW / OIHW) of TVM."""
296+ if "Conv1D" in conv_type :
297+ data_dic = {"a" : "N" , "b" : "C" , "c" : "W" }
298+ weight_dic = {"a" : "O" , "b" : "I" , "c" : "W" , "d" : "G" }
299+ elif "Conv2D" in conv_type :
300+ data_dic = {"a" : "N" , "b" : "C" , "c" : "H" , "d" : "W" }
301+ weight_dic = {"a" : "O" , "b" : "I" , "c" : "H" , "d" : "W" }
302+ if "e" in input_data :
303+ weight_dic = {"a" : "G" , "b" : "O" , "c" : "I" , "d" : "H" , "e" : "W" }
304+ elif "Conv3D" in conv_type :
305+ data_dic = {"a" : "N" , "b" : "C" , "c" : "D" , "d" : "H" , "e" : "W" }
306+ weight_dic = {"a" : "O" , "b" : "I" , "c" : "D" , "d" : "H" , "e" : "W" , "f" : "G" }
307+
308+ dic = weight_dic if is_weight else data_dic
309+ res = ""
310+
311+ for i in input_data :
312+ if i .isupper ():
313+ i = i .lower ()
314+ res += dic [i ]
315+ dic [i ] = dic [i ].lower ()
316+ elif i .islower ():
317+ res += dic [i ]
318+ elif i .isdigit ():
319+ res += i
320+ else :
321+ raise ValueError ("Unsupport layout format: %s" % input_data )
322+ return res
323+
324+
325+ def legalize_group_conv (attrs , inputs , types ):
326+ """Legalize group conv / conv_transpose calculation.
327+ Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
328+ groups = attrs .groups
329+ data , weight = inputs
330+ if groups == 1 :
331+ if "Transpose" not in type (attrs ).__name__ :
332+ return relay .nn .conv2d (data , weight , ** attrs )
333+ return relay .nn .conv2d_transpose (data , weight , ** attrs )
334+ OC , IC , H , W = get_shape (weight )
335+ new_attrs = dict (attrs )
336+ weight = relay .reshape (weight , (groups , OC // groups , IC , H , W ))
337+ if "Transpose" not in type (attrs ).__name__ :
338+ new_attrs ["kernel_layout" ] = "GOIHW"
339+ return relay .nn .conv2d (data , weight , ** new_attrs )
340+ new_attrs ["kernel_layout" ] = "GIOHW"
341+ return relay .nn .conv2d_transpose (data , weight , ** new_attrs )
342+
343+
344+ def alter_conv (attrs , inputs , tinfos , out_type ):
345+ """The convolution's layout auto-query func for dnnl."""
346+
347+ data , weight = inputs
348+ groups = str (attrs .groups )
349+ weight_shape = "," .join ([str (x ) for x in get_shape (weight )])
350+ out_shape = "," .join ([str (x ) for x in get_shape (out_type )])
351+ paddings = "," .join ([str (x ) for x in attrs .get_int_tuple ("padding" )])
352+ strides = "," .join ([str (x ) for x in attrs .get_int_tuple ("strides" )])
353+ dilates = "," .join ([str (x ) for x in attrs .get_int_tuple ("dilation" )])
354+ new_attrs = dict (attrs )
355+ conv_type = type (attrs ).__name__ .split ("Attrs" )[0 ]
356+
357+ res = get_optimal_layout_for_conv (
358+ attrs ["data_layout" ],
359+ attrs ["kernel_layout" ],
360+ weight_shape ,
361+ out_shape ,
362+ paddings ,
363+ strides ,
364+ dilates ,
365+ groups ,
242366 )
243- with tvm .transform .PassContext (opt_level = 3 ):
244- mod = seq (mod )
245- return mod
367+ src_df , weight_df , dst_df = res .split ("," )
368+ new_attrs ["data_layout" ] = tag2layout (src_df , is_weight = False , conv_type = conv_type )
369+ new_attrs ["kernel_layout" ] = tag2layout (weight_df , is_weight = True , conv_type = conv_type )
370+ new_attrs ["out_layout" ] = tag2layout (dst_df , is_weight = False , conv_type = conv_type )
371+
372+ if conv_type == "Conv1D" :
373+ return relay .nn .conv1d (data , weight , ** new_attrs )
374+ if conv_type == "Conv2D" :
375+ return relay .nn .conv2d (data , weight , ** new_attrs )
376+ return relay .nn .conv3d (data , weight , ** new_attrs )
377+
378+
379+ def alter_conv_transpose (attrs , inputs , tinfos , out_type ):
380+ """The transposed convolution's layout auto-query func for dnnl."""
381+
382+ data , weight = inputs
383+ weight_shape = "," .join ([str (x ) for x in get_shape (weight )])
384+ out_shape = "," .join ([str (x ) for x in get_shape (out_type )])
385+ paddings = "," .join ([str (x ) for x in attrs .get_int_tuple ("padding" )])
386+ output_paddings = "," .join ([str (x ) for x in attrs .get_int_tuple ("output_padding" )])
387+ strides = "," .join ([str (x ) for x in attrs .get_int_tuple ("strides" )])
388+ dilates = "," .join ([str (x ) for x in attrs .get_int_tuple ("dilation" )])
389+ groups = str (attrs .groups )
390+ new_attrs = dict (attrs )
391+ conv_type = type (attrs ).__name__ .split ("Attrs" )[0 ]
392+
393+ res = get_optimal_layout_for_conv_transpose (
394+ attrs ["data_layout" ],
395+ attrs ["kernel_layout" ],
396+ weight_shape ,
397+ out_shape ,
398+ paddings ,
399+ output_paddings ,
400+ strides ,
401+ dilates ,
402+ groups ,
403+ )
404+ src_df , weight_df , dst_df = res .split ("," )
405+ new_attrs ["data_layout" ] = tag2layout (src_df , is_weight = False , conv_type = conv_type )
406+ new_attrs ["kernel_layout" ] = tag2layout (weight_df , is_weight = True , conv_type = conv_type )
407+ new_attrs ["out_layout" ] = tag2layout (dst_df , is_weight = False , conv_type = conv_type )
408+
409+ if conv_type == "Conv1DTranspose" :
410+ return relay .nn .conv1d_transpose (data , weight , ** new_attrs )
411+ if conv_type == "Conv2DTranspose" :
412+ return relay .nn .conv2d_transpose (data , weight , ** new_attrs )
413+ return relay .nn .conv3d_transpose (data , weight , ** new_attrs )
0 commit comments