@@ -38,37 +38,31 @@ class FuseBatchNormPass(XNNPACKPass):
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
         constant_placeholders_to_delete = set()
-        for node in graph.nodes:
+        for input_node in graph.nodes:
             # We want to discover a chain of conv -> batch_norm or linear -> batch_norm.
-            # Only proceed if the current node is a conv or linear node, and has a single
-            # user/successor.
-            is_conv = node.target == exir_ops.edge.aten.convolution.default
-            is_linear = node.target == exir_ops.edge.aten.linear.default
+            # Only proceed if the current node is a conv or linear, and has a single user/successor.
+            is_conv = input_node.target == exir_ops.edge.aten.convolution.default
+            is_linear = input_node.target == exir_ops.edge.aten.linear.default

-            if not (is_conv or is_linear):
+            if not (is_conv or is_linear) or len(input_node.users) != 1:
                 continue
-            if len(node.users) != 1:
-                continue
-
-            # Conv or linear op to fuse.
-            target_op = node

-            # The single user of the op must be batch_norm. If not, bail.
-            bn = list(target_op.users.keys())[0]
+            # The single user of the conv or linear node must be batch_norm. If not, bail.
+            bn = list(input_node.users.keys())[0]
             if (
                 bn.target != exir_ops.edge.aten.native_batch_norm.default
                 and bn.target
                 != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
             ):
                 continue

-            if not self.can_fuse(target_op, bn, self.exported_program):
+            if not self.can_fuse(input_node, bn, self.exported_program):
                 continue

             self._fuse_ops(
                 graph_module,
                 graph,
-                target_op,
+                input_node,
                 bn,
                 is_conv,
                 constant_placeholders_to_delete,
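
Aside: the discovery loop above is easiest to see on a toy example. Below is a minimal, self-contained sketch on a plain `torch.fx` trace; `ConvBN` is a toy module invented for this illustration, and the real pass matches Edge-dialect targets (`exir_ops.edge.aten.convolution.default`) rather than `call_module` nodes.

```python
# Minimal sketch of the conv -> batch_norm chain discovery, on a plain
# torch.fx trace instead of the Edge-dialect graph the pass runs on.
import torch
import torch.fx


class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


gm = torch.fx.symbolic_trace(ConvBN())
for node in gm.graph.nodes:
    if node.op != "call_module" or not isinstance(
        gm.get_submodule(node.target), torch.nn.Conv2d
    ):
        continue
    # Fusion is only safe when the conv's output feeds nothing but the bn;
    # this mirrors the len(input_node.users) != 1 guard in the pass.
    if len(node.users) != 1:
        continue
    (user,) = node.users
    if user.op == "call_module" and isinstance(
        gm.get_submodule(user.target), torch.nn.BatchNorm2d
    ):
        print(f"fusable chain: {node.name} -> {user.name}")
```
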
@@ -81,38 +75,38 @@ def call(self, graph_module: torch.fx.GraphModule):
                 delete_constant_placeholder(self.exported_program, node)

         graph_module.recompile()
-        # To Regenerate metadata and shape information, retrace module.
+        # To regenerate metadata and shape information, retrace the module.
         graph_module = super().call(graph_module).graph_module

         return PassResult(graph_module, True)

     @staticmethod
     def can_fuse(
-        target_op: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
+        input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
     ) -> bool:
         """
-        Determine whether a batchnorm node can be fused with a preceding conv or linear node.
+        Determine whether a BatchNorm node can be fused with the preceding convolution or linear node.
         """

-        # All the users of batchnorm node must be getitem ops. batchnorm
-        # returns a 3-element tuple. Each user must only access the first
-        # element of the tuple.
+        # All users of the batch_norm node must be getitem ops.
+        # batch_norm returns a 3-element tuple.
+        # Each user must only access the first element of the tuple.
         if [
             (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
         ].count(False):
             return False

-        target_op_weights = target_op.args[1]
+        input_node_weights = input_node.args[1]
         bn_weights = bn.args[1]

-        # Check that the weights for conv or linear and batchnorm are both params.
-        if not isinstance(target_op_weights, torch.fx.Node) or not isinstance(
+        # Check that the weights for conv or linear and batch_norm are both params.
+        if not isinstance(input_node_weights, torch.fx.Node) or not isinstance(
             bn_weights, torch.fx.Node
         ):
             return False

         if [
-            is_param_node(program, node) for node in {target_op_weights, bn_weights}
+            is_param_node(program, node) for node in {input_node_weights, bn_weights}
         ].count(False):
             return False

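Side note on the `.count(False)` idiom in `can_fuse`: it materializes the whole list before testing it. An equivalent, short-circuiting formulation of the same check (a sketch, not part of this patch):

```python
import operator

import torch.fx


def all_users_take_output_only(bn: torch.fx.Node) -> bool:
    # native_batch_norm returns a 3-element tuple, so every consumer is an
    # operator.getitem node; the batch_norm may only be dropped if each
    # consumer reads index 0 (the normalized output).
    return all(
        user.target == operator.getitem and user.args[1] == 0
        for user in bn.users
    )
```
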
@@ -122,32 +116,45 @@ def _fuse_ops(
         self,
         graph_module: torch.fx.GraphModule,
         graph: torch.fx.Graph,
-        target_op: torch.fx.Node,
+        input_node: torch.fx.Node,
         bn: torch.fx.Node,
         is_conv: bool,
         constant_placeholders_to_delete: set,
     ) -> None:
         """
-        Fuse a BatchNorm into the preceding conv or linear op.
-        Update the fused op's weight and bias, rewire users of the BatchNorm's output, and remove the BatchNorm node.
+        Fuse a BatchNorm node into the preceding convolution or linear node.
+        Update the fused node's weight and bias, rewire users of the BatchNorm output,
+        and remove the BatchNorm node.
         """

         if is_conv:
-            assert len(target_op.args) == 9
-        else:  # Linear path: (input, weight, bias).
-            assert len(target_op.args) == 3
+            assert len(input_node.args) == 9
+            has_bias_arg = True
+        else:
+            # Otherwise, this is a linear node.
+            # Linear has 2 or 3 args depending on whether bias is used: (input, weight, bias).
+            assert len(input_node.args) in (2, 3)
+            has_bias_arg = len(input_node.args) == 3

         # Get the weight and bias parameters from the conv or linear op.
-        target_op_weight = get_param_tensor(self.exported_program, target_op.args[1])
-        target_op_weight_name = get_tensor_name(
-            self.exported_program, target_op.args[1]
+        input_node_weight = get_param_tensor(self.exported_program, input_node.args[1])
+        input_node_weight_name = get_tensor_name(
+            self.exported_program, input_node.args[1]
         )
-        assert target_op_weight is not None
+        assert input_node_weight is not None

-        target_op_bias = get_param_tensor(self.exported_program, target_op.args[2])
-        target_op_bias_name = get_tensor_name(self.exported_program, target_op.args[2])
+        if has_bias_arg:
+            input_node_bias = get_param_tensor(
+                self.exported_program, input_node.args[2]
+            )
+            input_node_bias_name = get_tensor_name(
+                self.exported_program, input_node.args[2]
+            )
+        else:
+            input_node_bias = None
+            input_node_bias_name = ""

-        # Get the parameters from the batchnorm op.
+        # Get the parameters from the batch_norm op.
         assert (
             bn.target == exir_ops.edge.aten.native_batch_norm.default
             and len(bn.args) == 8
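
To see why the linear branch must accept 2 or 3 args, one can export a `torch.nn.Linear` with and without bias. A standalone demonstration (not ExecuTorch code, and the exact node target can vary across PyTorch versions):

```python
import torch

for bias in (True, False):
    mod = torch.nn.Linear(4, 2, bias=bias)
    ep = torch.export.export(mod, (torch.randn(1, 4),))
    for node in ep.graph.nodes:
        if node.op == "call_function" and "linear" in str(node.target):
            # With bias=False the optional bias argument is simply omitted
            # from the node's args, rather than being passed as None, which
            # is what the has_bias_arg logic above accounts for.
            print(f"bias={bias}: {node.target} takes {len(node.args)} args")
```
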
@@ -169,10 +176,10 @@ def _fuse_ops(
         # as an arg).
         eps = bn.args[-1]

-        # Compute the updated weight and bias after fusing conv or linear op with batchnorm op.
+        # Compute the updated weight and bias after fusing the conv or linear op with the batch_norm op.
         fuse_args = (
-            target_op_weight,
-            target_op_bias,
+            input_node_weight,
+            input_node_bias,
             running_mean,
             running_var,
             eps,
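
For reference, the math these `fuse_args` feed: batch_norm applies the per-channel affine map y -> (y - mean) / sqrt(var + eps) * gamma + beta, which folds into the preceding layer's weight and bias. A sketch equivalent to `fuse_linear_bn_weights`, up to dtype and broadcasting details handled by the real helper:

```python
import torch


def fold_bn_into_linear(w, b, running_mean, running_var, eps, gamma, beta):
    # Per-output-channel scale factor derived from the batch_norm statistics.
    scale = gamma / torch.sqrt(running_var + eps)
    if b is None:
        b = torch.zeros_like(running_mean)
    fused_w = w * scale.unsqueeze(1)             # scale each output row of W
    fused_b = (b - running_mean) * scale + beta
    return fused_w, fused_b
```

The conv case is the same per-output-channel folding; the extra `is_transpose` flag passed to `fuse_conv_bn_weights` tells the helper which weight dimension indexes output channels.
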
@@ -181,23 +188,24 @@ def _fuse_ops(
         )

         if is_conv:
-            is_transpose = target_op.args[6]
+            is_transpose = input_node.args[6]
             fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose)
-        else:  # Linear path.
+        else:
+            # Otherwise, this is a linear node.
             fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args)

-        fused_weight_name = (target_op_weight_name + "_fused_bn").replace(".", "_")
-        if target_op_bias_name == "":
-            fused_bias_name = (target_op_weight_name + "_bias_fused_bn").replace(
+        fused_weight_name = (input_node_weight_name + "_fused_bn").replace(".", "_")
+        if input_node_bias_name == "":
+            fused_bias_name = (input_node_weight_name + "_bias_fused_bn").replace(
                 ".", "_"
             )
         else:
-            fused_bias_name = (target_op_bias_name + "_fused_bn").replace(".", "_")
+            fused_bias_name = (input_node_bias_name + "_fused_bn").replace(".", "_")

-        # Modify the graph by updating the weight and bias of conv or linear op
+        # Modify the graph by updating the weight and bias of the conv or linear op
         # with the fused weight and bias params, and replacing all the users
-        # of getitem(batchnorm) with the conv or linear op.
-        with graph.inserting_before(target_op.args[1]):
+        # of getitem(batch_norm) with the conv or linear op.
+        with graph.inserting_before(input_node.args[1]):
             fused_op_weight_node = create_constant_placeholder(
                 exp_program=self.exported_program,
                 graph=graph_module.graph,
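
The `inserting_before(input_node.args[1])` context is what keeps the new constant in the placeholder region at the head of the graph; `input_node.args[1]` is the existing weight placeholder, so it is a convenient anchor. A toy illustration with a bare `torch.fx.Graph`:

```python
import torch
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder("x")
w = graph.placeholder("w")
out = graph.call_function(torch.matmul, args=(x, w))
graph.output(out)

# New placeholders are created at the current insert point; anchoring it
# before an existing placeholder keeps them grouped at the graph's head.
with graph.inserting_before(w):
    fused_w = graph.placeholder("fused_w")

print(graph)  # fused_w is listed before w, still ahead of the matmul
```
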
@@ -216,17 +224,24 @@ def _fuse_ops(
         else:
             fused_op_bias_node = None

-        # Replace weight and bias with the fused batchnorm values.
-        args = list(target_op.args)
+        # Replace the original weight and bias with the fused batch_norm values.
+        args = list(input_node.args)
         args[1] = fused_op_weight_node
-        args[2] = fused_op_bias_node
-        target_op.args = tuple(args)

-        # Remove any use of batchnorm from the graph
+        if has_bias_arg:
+            # Overwrite the original bias with the fused bias.
+            args[2] = fused_op_bias_node
+        elif fused_op_bias_node is not None:
+            # Add the fused bias as a new argument, since the input_node originally had no bias.
+            args.append(fused_op_bias_node)
+
+        input_node.args = tuple(args)
+
+        # Remove any use of batch_norm from the graph.
         for user in bn.users.copy():
             assert user.target == operator.getitem
-            user.replace_all_uses_with(target_op)
+            user.replace_all_uses_with(input_node)
             graph.erase_node(user)

         graph.erase_node(bn)
-        constant_placeholders_to_delete.update(target_op.args[1:3] + bn.args[1:5])
+        constant_placeholders_to_delete.update(input_node.args[1:3] + bn.args[1:5])
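
The cleanup at the end is order-sensitive: `torch.fx` refuses to erase a node that still has users, so each `getitem` consumer must be rewired and erased before the `batch_norm` node itself can go. Distilled into a sketch of the same steps:

```python
import operator

import torch.fx


def rewire_and_erase(graph: torch.fx.Graph, fused: torch.fx.Node, bn: torch.fx.Node):
    # Redirect every getitem(bn, 0) consumer to the fused conv/linear node,
    # then erase the getitem nodes, and only then erase bn itself.
    for user in list(bn.users):
        assert user.target == operator.getitem
        user.replace_all_uses_with(fused)
        graph.erase_node(user)
    graph.erase_node(bn)
```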