55
66# pyre-unsafe
77
8- # Utiliy functions for TOSA quantized lowerings
8+ # Utility functions for TOSA quantized lowerings
99
1010import math
1111
@@ -27,11 +27,11 @@ def insert_rescale_ops_to_int32_maxscale(
2727 tosa_graph : Any , inputs : list [TosaArg ], node : Node , tosa_spec = None
2828) -> tuple [list [Any ], float ]:
2929 """For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale))
30- compared to all the other cases. We also multply the left and right scales by 1<<20 giving us extra precision
30+ compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision
3131 for the computation without overflowing.
3232
3333 Returns a list of the rescaled nodes and the scale factor used,
34- needed by rescale_node_back_to_int8 .
34+ needed by insert_rescale_op_to_int8 .
3535 """
3636
3737 if len (inputs ) > 2 :
@@ -86,7 +86,7 @@ def insert_rescale_ops_to_int32(
8686 The scales are adjusted using the smallest scale of all 'nodes'.
8787
8888 Returns a list of the rescaled nodes and the scale factor used,
89- needed by rescale_node_back_to_int8 .
89+ needed by insert_rescale_op_to_int8 .
9090
9191 This function is used in serialization to TOSA for target ops that are
9292 handled by the DQ/D folding pass, which stores the quantization parameters
@@ -134,7 +134,59 @@ def insert_rescale_op_to_int8(
134134 Parameters:
135135 node: The original node that is being handled by the rescales.
136136 last_tensor: the tosa tensor to rescale back.
137- scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
137+ scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
138+ compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
139+ tosa_graph: the tosa_graph to manipulate.
140+
141+ This function is used in serialization to TOSA for target ops that are
142+ handled by the DQ/D folding pass, which stores the quantization parameters
143+ in the node meta dict.
144+ """
145+ _insert_rescale_op_to_dtype (
146+ tosa_graph , last_tensor , scale , node , ts .DType .INT8 , compute_rescale , tosa_spec
147+ )
148+
149+
150+ def insert_rescale_op_to_int16 (
151+ tosa_graph : Any ,
152+ last_tensor : TosaArg ,
153+ scale : float ,
154+ node : Node ,
155+ compute_rescale = True ,
156+ tosa_spec = None ,
157+ ) -> None :
158+ """Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'.
159+ Parameters:
160+ node: The original node that is being handled by the rescales.
161+ last_tensor: the tosa tensor to rescale back.
162+ scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
163+ compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
164+ tosa_graph: the tosa_graph to manipulate.
165+
166+ This function is used in serialization to TOSA for target ops that are
167+ handled by the DQ/D folding pass, which stores the quantization parameters
168+ in the node meta dict.
169+ """
170+ _insert_rescale_op_to_dtype (
171+ tosa_graph , last_tensor , scale , node , ts .DType .INT16 , compute_rescale , tosa_spec
172+ )
173+
174+
175+ def _insert_rescale_op_to_dtype (
176+ tosa_graph : Any ,
177+ last_tensor : TosaArg ,
178+ scale : float ,
179+ node : Node ,
180+ output_dtype : Any ,
181+ compute_rescale = True ,
182+ tosa_spec = None ,
183+ ) -> None :
184+ """Common implementation for rescaling nodes back to a specific dtype.
185+ Parameters:
186+ node: The original node that is being handled by the rescales.
187+ last_tensor: the tosa tensor to rescale back.
188+ scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
189+ output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
138190 compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
139191 tosa_graph: the tosa_graph to manipulate.
140192
@@ -156,20 +208,21 @@ def insert_rescale_op_to_int8(
156208 else :
157209 output_rescale_scale = scale
158210
159- # Rescale Back to INT8
160- build_rescale_from_int32 (
211+ # Rescale Back to the specified dtype
212+ build_rescale_from_int32_to_dtype (
161213 tosa_graph ,
162214 last_tensor ,
163215 node .name ,
164216 qargs_out .get_zp_per_tensor (),
165217 output_rescale_scale ,
218+ output_dtype ,
166219 tosa_spec = tosa_spec ,
167220 )
168221
169222
170223# TOSA uses the RESCALE operation to scale between values with differing precision.
171224# The RESCALE operator is defined using an integer multiply, add, and shift.
172- # This utility function is for calculating the multier and shift given a scale.
225+ # This utility function is for calculating the multiplier and shift given a scale.
173226# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
174227def compute_multiplier_and_shift (
175228 scales : list [float ], scaleWidth : int = 32
@@ -214,7 +267,7 @@ def compute_multiplier_and_shift(
214267 return multipliers , shifts
215268
216269
217- # For TOSA spec v1.0 RESCALE operator requires multipler , shifts, input_zp and output_zp to be
270+ # For TOSA spec v1.0 RESCALE operator requires multiplier , shifts, input_zp and output_zp to be
218271# const inputs. Create constant operators from the data already initialized.
219272def create_const_ops_for_rescale (
220273 tosa_fb ,
@@ -335,14 +388,55 @@ def build_rescale_from_int32(
335388 per_channel : bool = False ,
336389 tosa_spec = None ,
337390) -> None :
391+ # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
392+ # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
393+ build_rescale_from_int32_to_dtype (
394+ tosa_fb ,
395+ input_node ,
396+ output_name ,
397+ output_zp ,
398+ rescale_scale ,
399+ ts .DType .INT8 ,
400+ is_scale32 ,
401+ is_double_round ,
402+ per_channel ,
403+ tosa_spec ,
404+ )
405+
406+ return
407+
408+
409+ def build_rescale_from_int32_to_dtype (
410+ tosa_fb : Any ,
411+ input_node : TosaArg ,
412+ output_name : str ,
413+ output_zp : int ,
414+ rescale_scale : float ,
415+ output_dtype : Any ,
416+ is_scale32 : bool = True ,
417+ is_double_round : bool = False ,
418+ per_channel : bool = False ,
419+ tosa_spec = None ,
420+ ) -> None :
421+ """Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16).
422+
423+ Parameters:
424+ tosa_fb: The TOSA serializer
425+ input_node: Input tensor (should be INT32)
426+ output_name: Name for the output tensor
427+ output_zp: Output zero point
428+ rescale_scale: Rescaling factor
429+ output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16)
430+ Other parameters: Standard rescale parameters
431+ """
338432 # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
339433 # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
340434 build_rescale (
341435 tosa_fb ,
342436 [rescale_scale ],
343437 input_node ,
344438 output_name = output_name ,
345- output_type = ts . DType . INT8 ,
439+ output_type = output_dtype ,
346440 input_zp = [0 ],
347441 output_zp = [output_zp ],
348442 rounding_mode = RoundingMode .SINGLE_ROUND ,
0 commit comments