 # LICENSE file in the root directory of this source tree.
 
 import torch
-import torch.nn as nn
 from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
 from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
-from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
 from .utils import copy_meta
@@ -23,16 +21,43 @@ class ConvertConv1dToConv2d(ExportPass):
     def __init__(self, edge_program: torch.export.ExportedProgram):
         super(ConvertConv1dToConv2d, self).__init__()
         self.edge_program = edge_program
+        self.conv_op_map = {
+            torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
+        }
+
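+    # Helper for quantized graphs: re-create the quantize/dequantize pair that
+    # annotated the original op around a newly inserted node, so scale/zero-point
+    # information is preserved. Returns the node unchanged when its neighbour is
+    # not a q/dq op.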
+    def append_qdq(
+        self,
+        graph_module: torch.fx.GraphModule,
+        node: torch.fx.Node,
+        qdq_node: torch.fx.Node,
+    ):
+        q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
+        dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+        if qdq_node.target not in {q_op, dq_op}:
+            return node
+
+        with graph_module.graph.inserting_after(node):
+            q_args = (node, *qdq_node.args[1:])
+            q_node = graph_module.graph.create_node("call_function", q_op, q_args)
+            q_node.meta = copy_meta(node.meta)
+            q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
+            with graph_module.graph.inserting_after(q_node):
+                dq_args = (q_node, *qdq_node.args[1:])
+                dq_node = graph_module.graph.create_node(
+                    "call_function", dq_op, dq_args
+                )
+                dq_node.meta = copy_meta(node.meta)
+
+        return dq_node
 
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        conv_op = exir_ops.edge.aten.convolution.default
         for node in graph.nodes:
-            if node.target == conv_op and node.meta["val"].dim() == 3:
-
+            if node.target in self.conv_op_map:
                 input_node = node.args[0]
                 with graph_module.graph.inserting_after(input_node):
-                    unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
+                    unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
                     unsqueeze_node = graph.create_node(
                         "call_function",
                         unsqueeze_op,
@@ -44,52 +69,88 @@ def call(self, graph_module: torch.fx.GraphModule):
                     unsqueeze_node.meta = copy_meta(
                         input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
                     )
+                    qdq_node_after_unsqueeze = self.append_qdq(
+                        graph_module=graph_module,
+                        node=unsqueeze_node,
+                        qdq_node=input_node,
+                    )
 
-                    with graph_module.graph.inserting_after(unsqueeze_node):
-
-                        filter_node = node.args[1]
+                    with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
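+                        # node.args[1] is either the weight placeholder itself or a node
+                        # (e.g. a dequantize op in quantized graphs) whose first argument
+                        # is that placeholder.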
+                        filter_arg = node.args[1]
+                        filter_node = (
+                            filter_arg
+                            if filter_arg.op == "placeholder"
+                            else node.args[1].args[0]
+                        )
                         filter_node.meta["val"] = (
                             filter_node.meta["val"].unsqueeze(2).contiguous()
                         )
-                        filter_tensor = get_parameter(filter_node, self.edge_program)
-                        # Ensure tensor is nn.Parameter type, so program does not fail during edge_program._validate()
-                        filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2))
-                        set_parameter(filter_tensor, filter_node, self.edge_program)
+                        filter_tensor = get_parameter(
+                            filter_node, self.edge_program
+                        ).unsqueeze(2)
+                        set_parameter(
+                            (
+                                torch.nn.Parameter(filter_tensor)
+                                if filter_tensor.dtype == torch.float
+                                else filter_tensor
+                            ),
+                            filter_node,
+                            self.edge_program,
+                        )
 
+                        num_args = len(node.args)
                         bias_node = node.args[2]
-                        stride = [1] + node.args[3]
-                        padding = [0] + node.args[4]
-                        dilation = [1] + node.args[5]
-                        transpose = node.args[6]
-                        output_padding = [0] + node.args[7]
-                        groups = node.args[8]
-
-                        conv2d_node = graph.create_node(
-                            "call_function",
-                            conv_op,
-                            (
-                                unsqueeze_node,
-                                filter_node,
+                        stride = [1] + node.args[3] if num_args > 3 else [1, 1]
+                        padding = [0] + node.args[4] if num_args > 4 else [0, 0]
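+                        # aten.conv1d takes (input, weight, bias, stride, padding,
+                        # dilation, groups), while conv_transpose2d.input takes
+                        # (input, weight, bias, stride, padding, output_padding,
+                        # groups, dilation), so the argument tuples are built separately.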
+                        if node.target == torch.ops.aten.conv1d.default:
+                            dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
+                            groups = node.args[6] if num_args > 6 else 1
+                            conv_args = (
+                                qdq_node_after_unsqueeze,
+                                node.args[1],
                                 bias_node,
                                 stride,
                                 padding,
                                 dilation,
-                                transpose,
+                                groups,
+                            )
+                        else:
+                            output_padding = (
+                                [0] + node.args[5] if num_args > 5 else [0, 0]
+                            )
+                            groups = node.args[6] if num_args > 6 else 1
+                            dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
+                            conv_args = (
+                                qdq_node_after_unsqueeze,
+                                node.args[1],
+                                bias_node,
+                                stride,
+                                padding,
                                 output_padding,
                                 groups,
-                            ),
+                                dilation,
+                            )
+                        conv2d_node = graph.create_node(
+                            "call_function",
+                            self.conv_op_map[node.target],
+                            conv_args,
                         )
                         conv2d_node.meta = copy_meta(
                             node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
                         )
+                        qdq_node_after_conv2d = self.append_qdq(
+                            graph_module=graph_module,
+                            node=conv2d_node,
+                            qdq_node=list(node.users)[0],
+                        )
 
-                        with graph_module.graph.inserting_after(conv2d_node):
-                            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
+                        with graph_module.graph.inserting_after(qdq_node_after_conv2d):
+                            squeeze_op = torch.ops.aten.squeeze_copy.dims
                             squeeze_node = graph.create_node(
                                 "call_function",
                                 squeeze_op,
                                 (
-                                    conv2d_node,
+                                    qdq_node_after_conv2d,
                                     [2],
                                 ),
                             )
@@ -102,8 +163,10 @@ def call(self, graph_module: torch.fx.GraphModule):
                                     QCOM_REQUANTIZE
                                 ]
                                 conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
+
                 for user in node.users.copy():
                     user.replace_input_with(node, squeeze_node)
+
         graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
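
For reference, a minimal usage sketch of how a pass like this can be exercised in isolation. The toy module, the direct call on the exported program's graph module, and the import path are assumptions made for illustration and are not part of this change; in practice the pass runs inside the Qualcomm backend's lowering pipeline.

# Illustrative sketch only; import path and invocation style are assumptions.
import torch

from executorch.backends.qualcomm._passes.convert_conv1d_to_conv2d import (
    ConvertConv1dToConv2d,
)


class ToyConv1d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(4, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


ep = torch.export.export(ToyConv1d(), (torch.randn(1, 4, 16),))
# Rewrites each aten.conv1d node as unsqueeze -> conv2d -> squeeze in place and
# returns a PassResult wrapping the mutated graph module.
result = ConvertConv1dToConv2d(ep)(ep.graph_module)
print(result.graph_module.graph)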