44# LICENSE file in the root directory of this source tree.
55
66
7- import itertools
87from typing import Set , Type
98
109import torch
1716 is_buffer ,
1817 is_param ,
1918)
20- from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
21- get_input_qparams ,
22- get_output_qparams ,
23- )
2419from executorch .backends .arm .constants import HWCM_ORDER , NHWC_INVERSE_ORDER
2520from executorch .backends .arm .tosa .mapping import TosaSpecialDtype
2621from executorch .backends .transforms .utils import create_constant_placeholder
@@ -161,40 +156,6 @@ def _add_bias(
161156 node .update_arg (2 , bias_node )
162157 return bias_node
163158
164- def insert_output_rescale (self , graph_module , node ):
165- input_qparams = get_input_qparams (node )
166- output_qparams = get_output_qparams (node )[0 ]
167- weight_qparams = input_qparams [1 ]
168- input_qparams = input_qparams [0 ]
169- is_per_channel = weight_qparams .per_channel
170- if is_per_channel :
171- weight_scale = weight_qparams .get_scale_per_channel ()
172- else :
173- weight_scale = [weight_qparams .get_scale_per_tensor ()]
174- input_scale = input_qparams .get_scale_per_tensor ()
175- post_conv2d_scale = [
176- (inp * w ) / out
177- for inp , w , out in zip (
178- itertools .cycle ([input_scale ]),
179- weight_scale ,
180- itertools .cycle ([output_qparams .get_scale_per_tensor ()]),
181- )
182- ]
183- with graph_module .graph .inserting_after (node ):
184- rescale_node = create_node (
185- graph = graph_module .graph ,
186- op_target = exir_ops .backend .tosa .RESCALE .default ,
187- args = (
188- node ,
189- output_qparams .dtype ,
190- post_conv2d_scale ,
191- 0 ,
192- output_qparams .get_zp_per_tensor (),
193- ),
194- from_node = node ,
195- )
196- return rescale_node
197-
198159 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
199160 modified = False
200161 for node in graph_module .graph .nodes :
@@ -219,20 +180,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
219180 ) = node .args
220181
221182 pad = [val for val in pad for _ in (0 , 1 )]
222- input_fake_tensor = get_first_fake_tensor (x )
223- weight_fake_tensor = get_first_fake_tensor (weight )
183+ input_shape = get_first_fake_tensor (x ). shape
184+ weight_shape = get_first_fake_tensor (weight ). shape
224185 # Adjust the pad value if needed to meet the
225186 # strict convolution output shape calculation.
226187 pad [1 ] = self ._adjust_pad_if_needed (
227- input_fake_tensor . shape [2 ],
228- weight_fake_tensor . shape [2 ],
188+ input_shape [2 ],
189+ weight_shape [2 ],
229190 stride [0 ],
230191 pad [1 ],
231192 dilation [0 ],
232193 )
233194 pad [3 ] = self ._adjust_pad_if_needed (
234- input_fake_tensor . shape [3 ],
235- weight_fake_tensor . shape [3 ],
195+ input_shape [3 ],
196+ weight_shape [3 ],
236197 stride [1 ],
237198 pad [3 ],
238199 dilation [1 ],
@@ -243,8 +204,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
243204
244205 if self ._is_depthwise_conv2d (node ):
245206 target_op = exir_ops .backend .tosa .DEPTHWISE_CONV2D .default
246- self ._reshape_weights (weight , input_fake_tensor .shape [1 ])
247- weight_fake_tensor = get_first_fake_tensor (weight )
207+ self ._reshape_weights (weight , input_shape [1 ])
248208 else :
249209 target_op = exir_ops .backend .tosa .CONV2D .default
250210
@@ -267,29 +227,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
267227 args = conv2d_args ,
268228 from_node = node ,
269229 )
270- bias_fake_tensor = get_first_fake_tensor (bias ) if bias else None
271- tosa_node_fake_tensor = target_op (
272- input_fake_tensor ,
273- weight_fake_tensor ,
274- bias_fake_tensor ,
275- * conv2d_args [3 :],
276- )
277230
278- if (
279- tosa_node_fake_tensor .dtype == torch .int32
280- and input_fake_tensor .dtype == torch .int8
281- ) or (
282- tosa_node_fake_tensor .dtype == torch .int32
283- and input_fake_tensor .dtype == torch .int16
284- ):
285- output_rescale = self .insert_output_rescale (graph_module , tosa_op )
286- node .replace_all_uses_with (output_rescale )
287- if input_fake_tensor .dtype == torch .int16 :
288- tosa_op .meta [TosaSpecialDtype .meta_key ()] = TosaSpecialDtype .INT48
289- else :
290231 node .replace_all_uses_with (tosa_op )
291-
292- graph_module .graph .erase_node (node )
232+ graph_module .graph .erase_node (node )
293233
294234 if modified :
295235 graph_module .recompile ()
0 commit comments