@@ -32,12 +32,11 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
         constant_placeholders_to_delete = set()
         for linear in graph.nodes:
-            # We want to discover a chain of linear -> batch_norm or addmm -> batch_norm.
-            # Only proceed if the current node is a linear or addmm node, and has a single
+            # We want to discover a chain of linear -> batch_norm.
+            # Only proceed if the current node is a linear node, and has a single
             # user/successor.
             if (
                 linear.target != exir_ops.edge.aten.linear.default
-                and linear.target != exir_ops.edge.aten.addmm.default
                 or len(linear.users) != 1
             ):
                 continue
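For context, here is a minimal sketch of an eager-mode module that produces the linear -> batch_norm chain this pass now matches. The module and variable names are illustrative only, and it assumes the export/to_edge pipeline keeps aten.linear rather than decomposing it to addmm:

import torch

class LinearBN(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 8)
        self.bn = torch.nn.BatchNorm1d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The linear node's only user is the batch_norm node,
        # so the single-user check above passes.
        return self.bn(self.linear(x))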
@@ -51,34 +50,18 @@ def call(self, graph_module: torch.fx.GraphModule):
             ):
                 continue
 
+            if not self.can_fuse(linear, bn, self.exported_program):
+                continue
+
             # Get the parameters
             assert len(linear.args) == 3
 
-            if linear.target == exir_ops.edge.aten.addmm.default:
-                # addmm.args = (bias, input, weight)
-                linear_bias_arg = linear.args[0]
-                linear_input_arg = linear.args[1]
-                # Unwrap permute_copy to access weight parameter node
-                linear_weight_arg = FuseBatchNormWithLinearPass._unwrap_node(
-                    linear.args[2]
-                )
-            else:
-                # linear.args = (input, weight, bias)
-                linear_input_arg = linear.args[0]
-                linear_weight_arg = linear.args[1]
-                linear_bias_arg = linear.args[2]
-
-            if not self.can_fuse(linear_weight_arg, bn, self.exported_program):
-                continue
-
-            linear_weight = get_param_tensor(self.exported_program, linear_weight_arg)
-            linear_weight_name = get_tensor_name(
-                self.exported_program, linear_weight_arg
-            )
+            linear_weight = get_param_tensor(self.exported_program, linear.args[1])
+            linear_weight_name = get_tensor_name(self.exported_program, linear.args[1])
             assert linear_weight is not None
 
-            linear_bias = get_param_tensor(self.exported_program, linear_bias_arg)
-            linear_bias_name = get_tensor_name(self.exported_program, linear_bias_arg)
+            linear_bias = get_param_tensor(self.exported_program, linear.args[2])
+            linear_bias_name = get_tensor_name(self.exported_program, linear.args[2])
 
             # Get the parameters from the batchnorm op
             assert (
@@ -112,12 +95,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 bn_weight,
                 bn_bias,
             )
-
-            if linear.target == exir_ops.edge.aten.addmm.default:
-                # fuse_linear_bn_weights returns weight [out, in];
-                # permute_copy node was removed, so weight must be transposed to [in, out] for addmm
-                fused_weight = fused_weight.t()
-
             fused_weight_name = (linear_weight_name + "_fused_bn").replace(".", "_")
             if linear_bias_name == "":
                 fused_bias_name = (linear_weight_name + "_bias_fused_bn").replace(
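For intuition, the folding performed by fuse_linear_bn_weights (presumably the torch.nn.utils.fusion helper) amounts to scaling each output row of the weight and shifting the bias. This standalone sketch shows the math only; the function name is hypothetical and it is not the pass's code:

import torch

def fold_bn_into_linear(w, b, running_mean, running_var, eps, gamma, beta):
    # BN in eval mode: y = gamma * (z - running_mean) / sqrt(running_var + eps) + beta,
    # where z = x @ w.T + b, so the composition is again an affine (linear) map.
    scale = gamma * torch.rsqrt(running_var + eps)  # one scale per output channel
    fused_w = w * scale.unsqueeze(-1)               # [out, in], each row scaled
    fused_b = (b - running_mean) * scale + beta
    return fused_w, fused_b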
@@ -130,7 +107,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             # with the fused weight and bias params, and replacing all the users
             # of getitem(batchnorm) with the linear op.
 
-            with graph.inserting_before(linear_weight_arg):
+            with graph.inserting_before(linear.args[1]):
                 fused_linear_weight_node = create_constant_placeholder(
                     exp_program=self.exported_program,
                     graph=graph_module.graph,
@@ -149,20 +126,11 @@ def call(self, graph_module: torch.fx.GraphModule):
             else:
                 fused_linear_bias_node = None
 
-            if linear.target == exir_ops.edge.aten.addmm.default:
-                # addmm.args = (bias, input, weight)
-                linear.args = (
-                    fused_linear_bias_node,
-                    linear_input_arg,
-                    fused_linear_weight_node,
-                )
-            else:
-                # linear.args = (input, weight, bias)
-                linear.args = (
-                    linear_input_arg,
-                    fused_linear_weight_node,
-                    fused_linear_bias_node,
-                )
+            linear.args = (
+                linear.args[0],
+                fused_linear_weight_node,
+                fused_linear_bias_node,
+            )
 
             # Remove any use of batchnorm from the graph
             for user in bn.users.copy():
@@ -187,7 +155,7 @@ def call(self, graph_module: torch.fx.GraphModule):
 
     @staticmethod
     def can_fuse(
-        linear_weights: torch.fx.Node,
+        linear: torch.fx.Node,
         bn: torch.fx.Node,
         program: ExportedProgram,
     ) -> bool:
@@ -206,23 +174,11 @@ def can_fuse(
         bn_weights = bn.args[1]
 
         # Check that the weights for linear and batchnorm are both params
-        if not isinstance(linear_weights, torch.fx.Node) or not isinstance(
+        if not isinstance(linear, torch.fx.Node) or not isinstance(
             bn_weights, torch.fx.Node
         ):
             return False
 
-        if [
-            is_param_node(program, node) for node in {linear_weights, bn_weights}
-        ].count(False):
+        if [is_param_node(program, node) for node in {linear, bn_weights}].count(False):
             return False
-
         return True
-
-    @staticmethod
-    def _unwrap_node(node: torch.fx.Node) -> torch.fx.Node:
-        while node.op == "call_function" and node.target in {
-            exir_ops.edge.aten.permute.default,
-            exir_ops.edge.aten.permute_copy.default,
-        }:
-            node = node.args[0]
-        return node
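Separately from this pass, the fusion itself can be sanity-checked numerically with the public torch.nn.utils.fusion helper; this is a quick sketch, with BatchNorm1d in eval mode so the running statistics are applied:

import torch
from torch.nn.utils.fusion import fuse_linear_bn_weights

lin = torch.nn.Linear(4, 8)
bn = torch.nn.BatchNorm1d(8).eval()
# Give the BN non-trivial statistics so the comparison is meaningful
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 2.0)
bn.bias.data.uniform_(-1.0, 1.0)

w, b = fuse_linear_bn_weights(
    lin.weight, lin.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)
fused = torch.nn.Linear(4, 8)
fused.weight.data.copy_(w)
fused.bias.data.copy_(b)

x = torch.randn(2, 4)
assert torch.allclose(bn(lin(x)), fused(x), atol=1e-5)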