2121from tvm import relay , transform
2222from tvm .driver .tvmc import TVMCException
2323
24- # ToMixedPrecision
25- ACC_DTYPE = "float32"
2624
def generate_mixed_precision_rule(acc_dtype):
    """Build a mixed precision conversion rule bound to ``acc_dtype``.

    The accumulation dtype is captured in a closure, so every generated rule
    carries its own value and no module-level state is involved.

    Parameters
    ----------
    acc_dtype : str
        Value placed in the second slot of the list returned by the rule.

    Returns
    -------
    rule : callable
        Function taking ``(call_node, mixed_precision_type)`` and returning
        ``[MIXED_PRECISION_ALWAYS, acc_dtype, mixed_precision_type]``.
    """

    def _mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
        # call_node is accepted to match the expected rule signature but is
        # not inspected: the same answer is returned for every call.
        return [
            relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
            acc_dtype,
            mixed_precision_type,
        ]

    return _mixed_precision_rule
3534
3635
3736class MixedPrecision (object ):
3837 """Temporarily changes attr of ops to enable required precision."""
3938
40- def __init__ (self , ops ):
39+ def __init__ (self , ops , acc_type ):
4140 """Saves the required info for RAII pattern usage.
4241
4342 Parameters
4443 ----------
4544 ops : list
4645 list of operators
46+ acc_type: str
47+ Output or accumulation precision to be used.
4748 """
4849 self .older_attr = {}
4950 self .ops = ops
51+ self .acc_type = acc_type
5052 self .attr_key = "FTVMMixedPrecisionConversionType"
5153
5254 def __enter__ (self ):
5355 for op_name in self .ops :
5456 op = relay .op .get (op_name )
5557 self .older_attr [op_name ] = op .get_attr (self .attr_key )
5658 op .reset_attr (self .attr_key )
57- op .set_attr (self .attr_key , mixed_precision_rule )
59+ op .set_attr (self .attr_key , generate_mixed_precision_rule ( self . acc_type ) )
5860 return self
5961
6062 def __exit__ (self , ptype , value , trace ):
@@ -65,20 +67,18 @@ def __exit__(self, ptype, value, trace):
6567 op .set_attr (self .attr_key , self .older_attr [op_name ])
6668
6769
68- def convert_to_mixed_precision (
69- mod , ops = "nn.conv2d,nn.dense" , input_type = "float16" , out_type = "float16"
70- ):
70+ def convert_to_mixed_precision (mod , ops = None , calculation_type = "float16" , acc_type = "float16" ):
7171 """Converts the operator datatypes
7272
7373 Parameters
7474 ----------
7575 mod : tvm.IRModule
7676 The relay module to convert.
77- ops : str
77+ ops : list
7878 List of operators to be precision converted.
79- input_type : str
79+ calculation_type : str
8080 Input precision to be used.
81- output_type : str
81+ acc_type : str
8282 Output or accumulation precision to be used.
8383
8484 Returns
@@ -87,10 +87,10 @@ def convert_to_mixed_precision(
8787 The converted module.
8888 """
8989
90- global ACC_DTYPE
91- ACC_DTYPE = out_type
90+ if ops is None :
91+ ops = [ "nn.conv2d" , "nn.dense" ]
9292
93- with MixedPrecision (ops . split ( "," ) ):
93+ with MixedPrecision (ops , acc_type ):
9494 seq = transform .Sequential (
9595 [relay .transform .InferType (), relay .transform .ToMixedPrecision ()]
9696 )
@@ -103,7 +103,7 @@ def convert_to_mixed_precision(
103103 raise TVMCException ("Error converting mixed precision : {0}" .format (str (err )))
104104
105105
106- def convert_graph_layout (mod , desired_layout , ops = "nn.conv2d,nn.conv2d_transpose,qnn.conv2d" ):
106+ def convert_graph_layout (mod , desired_layout , ops = None ):
107107 """Alter the layout of the input graph.
108108
109109 Parameters
@@ -112,16 +112,18 @@ def convert_graph_layout(mod, desired_layout, ops="nn.conv2d,nn.conv2d_transpose
112112 The relay module to convert.
113113 desired_layout : str
114114 The layout to convert to.
115- ops : str
115+ ops : list
116116 List of operators to be layout converted.
117117
118118 Returns
119119 -------
120120 mod : tvm.IRModule
121121 The converted module.
122122 """
123+ if ops is None :
124+ ops = ["nn.conv2d" , "nn.conv2d_transpose" , "qnn.conv2d" ]
123125
124- desired_layouts = {op : [desired_layout , "default" ] for op in ops . split ( "," ) }
126+ desired_layouts = {op : [desired_layout , "default" ] for op in ops }
125127
126128 # Convert the layout of the graph where possible.
127129 seq = transform .Sequential (
@@ -164,9 +166,9 @@ def apply_graph_transforms(mod, args):
164166 if args .get ("mixed_precision" , False ):
165167 mod = convert_to_mixed_precision (
166168 mod ,
167- args .get ("mixed_precision_ops" , "nn.conv2d,nn.dense" ),
168- args .get ("mixed_precision_input" , "float16 " ),
169- args .get ("mixed_precision_output" , "float16 " ),
169+ args .get ("mixed_precision_ops" ),
170+ args .get ("mixed_precision_calculation_type " ),
171+ args .get ("mixed_precision_acc_type " ),
170172 )
171173 return mod
172174
@@ -176,26 +178,27 @@ def parse_graph_transform_args(args):
176178
177179 Parameters
178180 ----------
179- args: argparse.Namespace
180- Arguments from command line parser .
181+ args: argparse.Namespace or dict
182+ Arguments.
181183
182184 Returns
183185 -------
184186 transform_args : dict
185187 Graph transform arguments
186188 """
187189
188- args_dict = vars (args )
190+ if not isinstance (args , dict ):
191+ args = vars (args )
189192
190193 transform_args = [
191194 "desired_layout" ,
192195 "desired_layout_ops" ,
193196 "mixed_precision" ,
194197 "mixed_precision_ops" ,
195- "mixed_precision_input " ,
196- "mixed_precision_output " ,
198+ "mixed_precision_calculation_type " ,
199+ "mixed_precision_acc_type " ,
197200 ]
198- transform_args = {key : args_dict .get (key , None ) for key in transform_args }
201+ transform_args = {key : args .get (key , None ) for key in transform_args }
199202 return transform_args
200203
201204
@@ -211,7 +214,8 @@ def generate_transform_args(parser):
211214 )
212215 parser .add_argument (
213216 "--desired-layout-ops" ,
214- default = "nn.conv2d,nn.conv2d_transpose,qnn.conv2d" ,
217+ default = ["nn.conv2d" , "nn.conv2d_transpose" , "qnn.conv2d" ],
218+ nargs = "+" ,
215219 help = "List of operators to be layout converted." ,
216220 )
217221
@@ -223,18 +227,19 @@ def generate_transform_args(parser):
223227 )
224228 parser .add_argument (
225229 "--mixed-precision-ops" ,
226- default = "nn.conv2d,nn.dense" ,
230+ default = ["nn.conv2d" , "nn.dense" ],
231+ nargs = "+" ,
227232 help = "List of operators to be converted to mixed precision" ,
228233 )
229234 parser .add_argument (
230- "--mixed-precision-input " ,
235+ "--mixed-precision-calculation-type " ,
231236 choices = ["float16" , "float32" ],
232237 default = "float16" ,
233- help = "Input precision type" ,
238+ help = "Calculation precision type" ,
234239 )
235240 parser .add_argument (
236- "--mixed-precision-output " ,
241+ "--mixed-precision-acc-type " ,
237242 choices = ["float16" , "float32" ],
238243 default = "float16" ,
239- help = "Output or accumulator precision type" ,
244+ help = "Accumulator precision type" ,
240245 )
0 commit comments