 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
 import numpy as np
 import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS

-from .node_visitor import NodeVisitor
+from .node_visitor import NodeVisitor, PER_CHANNEL_ENCODING
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import (
     OpConv2d,
@@ -101,6 +101,29 @@ def _add_conv_op_parameter(

         return conv_op

+    def _reduce_bias_scales(
+        self,
+        node: torch.fx.Node,
+        filter_node: torch.fx.Node,
+        bias_node: torch.fx.Node,
+        groups: int,
+    ):
+        """
+        When a transpose_conv has groups, the bias_node's per-channel quant attributes need special handling.
+        See _derived_bias_quant_spec under backends/qualcomm/quantizer/qconfig.py for more info.
+        """
+
+        filter_scales = filter_node.meta[QCOM_QUANT_ATTRS]["scales"]
+        bias_scales = bias_node.meta[QCOM_QUANT_ATTRS]["scales"]
+        bias_zero_points = bias_node.meta[QCOM_QUANT_ATTRS]["zero_points"]
+
+        # Guard so the reduction only happens once, since this runs in both op_validation and qnn_preprocess
+        if filter_scales.numel() != bias_scales.numel():
+            bias_scales = bias_scales.view(-1, groups)[:, 0]
+            bias_zero_points = bias_zero_points.view(-1, groups)[:, 0]
+            bias_node.meta[QCOM_QUANT_ATTRS]["scales"] = bias_scales
+            bias_node.meta[QCOM_QUANT_ATTRS]["zero_points"] = bias_zero_points
+
     def define_node(
         self,
         node: torch.fx.Node,
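A quick, standalone sketch (not part of the diff) of what the `view(-1, groups)[:, 0]` reduction above does. The toy scale values, and the assumption that the derived bias scales arrive in consecutive runs of `groups` entries per output channel, are illustrative only:

```python
import torch

groups = 2
# Hypothetical per-channel scales: 3 filter output channels, but the derived
# bias quant spec produced groups * 3 = 6 entries (assumed layout: each output
# channel's scale repeated `groups` times back to back).
filter_scales = torch.tensor([0.10, 0.20, 0.30])
bias_scales = torch.tensor([0.10, 0.10, 0.20, 0.20, 0.30, 0.30])

if filter_scales.numel() != bias_scales.numel():
    # Keep every `groups`-th entry, i.e. one scale per output channel.
    bias_scales = bias_scales.view(-1, groups)[:, 0]

print(bias_scales)  # tensor([0.1000, 0.2000, 0.3000])
```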
@@ -127,8 +150,15 @@ def define_node(

         filter_node = self.get_node(node.args[1])
         filter_tensor = get_parameter(filter_node, self.edge_program)
+
+        stride = cast(List[int], node.args[3])
+        padding = cast(List[int], node.args[4])
+        dilation = cast(List[int], node.args[5])
+        output_padding = cast(List[int], node.args[7])
+        groups = cast(int, node.args[8])
+
         # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d),
-        # yet QNN is HWIO or DHWIO
+        # yet QNN is HWIO or DHWIO for both conv and conv_transpose.
         is_transpose_conv = cast(bool, node.args[6])
         if is_conv2d:
             filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
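A small sketch (illustrative, with made-up shapes) of how those `filter_axis_order` permutations map the PyTorch 2D weight layouts onto the HWIO layout mentioned in the comment above:

```python
import torch

# conv2d weight is OIHW: (Cout, Cin / groups, H, W)
w_conv = torch.randn(8, 4, 3, 3)
print(w_conv.permute(2, 3, 1, 0).shape)   # torch.Size([3, 3, 4, 8]) -> (H, W, I, O)

# conv_transpose2d weight is IOHW: (Cin, Cout / groups, H, W)
w_tconv = torch.randn(4, 8, 3, 3)
print(w_tconv.permute(2, 3, 0, 1).shape)  # torch.Size([3, 3, 4, 8]) -> (H, W, I, O)
```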
@@ -147,6 +177,16 @@ def define_node(
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
         if node.args[2] is not None:
             bias_node = self.get_node(node.args[2])
+            # TODO: Revisit the condition below once QNN supports transpose_conv with block_quant.
+            # The node.args[1].target check only lets per_channel_quant through and bypasses block_quant.
+            if (
+                is_transpose_conv
+                and groups != 1
+                and bias_node.meta.get(QCOM_QUANT_ATTRS) is not None
+                and node.args[1].target in PER_CHANNEL_ENCODING
+            ):
+                self._reduce_bias_scales(node, filter_node, bias_node, groups)
+
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
@@ -156,7 +196,6 @@ def define_node(
                 nodes_to_wrappers,
             )
             conv_input_tensors.append(bias_tensor_wrapper)
-
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
@@ -167,11 +206,6 @@ def define_node(
         )
         conv_output_tensors = [output_tensor_wrapper]

-        stride = cast(List[int], node.args[3])
-        padding = cast(List[int], node.args[4])
-        dilation = cast(List[int], node.args[5])
-        output_padding = cast(List[int], node.args[7])
-        groups = cast(int, node.args[8])
         # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout)
         group_input_channels = filter_tensor.shape[-2]
         group_output_channels = int(filter_tensor.shape[-1] / groups)