
Commit 20afaa8

Enable BatchNorm fusion for Linear with bias=False
1 parent 9bd8064 commit 20afaa8
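
For context: folding BatchNorm into a preceding Linear is a pure re-parameterization. Inference-mode BatchNorm is an affine map, so its scale and shift can be absorbed into the linear weight and bias; when the Linear was created with bias=False, the missing bias is simply treated as zero, which is what this commit enables. The pass delegates the arithmetic to fuse_linear_bn_weights (see the diff below); the following is a minimal sketch of that computation, with illustrative variable names rather than the ones used by the pass:

import torch

def fuse_linear_bn_sketch(w, b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # A missing linear bias (bias=False) behaves as zero.
    if b is None:
        b = torch.zeros_like(bn_rm)
    # Inference-mode BatchNorm: y = (x - running_mean) / sqrt(running_var + eps) * gamma + beta
    scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
    fused_w = w * scale.unsqueeze(-1)      # rescale each output row of the linear weight
    fused_b = (b - bn_rm) * scale + bn_b   # shift the (possibly zero) bias
    return fused_w, fused_b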

File tree

backends/xnnpack/_passes/fuse_batch_norm.py
backends/xnnpack/test/passes/test_batch_norm_fusion.py

2 files changed: +84, -68 lines changed

backends/xnnpack/_passes/fuse_batch_norm.py

Lines changed: 72 additions & 57 deletions
@@ -38,37 +38,31 @@ class FuseBatchNormPass(XNNPACKPass):
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
         constant_placeholders_to_delete = set()
-        for node in graph.nodes:
+        for input_node in graph.nodes:
             # We want to discover a chain of conv -> batch_norm or linear -> batch_norm.
-            # Only proceed if the current node is a conv or linear node, and has a single
-            # user/successor.
-            is_conv = node.target == exir_ops.edge.aten.convolution.default
-            is_linear = node.target == exir_ops.edge.aten.linear.default
+            # Only proceed if the current node is a conv or linear, and has a single user/successor.
+            is_conv = input_node.target == exir_ops.edge.aten.convolution.default
+            is_linear = input_node.target == exir_ops.edge.aten.linear.default

-            if not (is_conv or is_linear):
+            if not (is_conv or is_linear) or len(input_node.users) != 1:
                 continue
-            if len(node.users) != 1:
-                continue
-
-            # Conv or linear op to fuse.
-            target_op = node

-            # The single user of the op must be batch_norm. If not, bail.
-            bn = list(target_op.users.keys())[0]
+            # The single user of the conv or linear node must be batch_norm. If not, bail.
+            bn = list(input_node.users.keys())[0]
             if (
                 bn.target != exir_ops.edge.aten.native_batch_norm.default
                 and bn.target
                 != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
             ):
                 continue

-            if not self.can_fuse(target_op, bn, self.exported_program):
+            if not self.can_fuse(input_node, bn, self.exported_program):
                 continue

             self._fuse_ops(
                 graph_module,
                 graph,
-                target_op,
+                input_node,
                 bn,
                 is_conv,
                 constant_placeholders_to_delete,
@@ -81,38 +75,38 @@ def call(self, graph_module: torch.fx.GraphModule):
                 delete_constant_placeholder(self.exported_program, node)

         graph_module.recompile()
-        # To Regenerate metadata and shape information, retrace module.
+        # To regenerate metadata and shape information, retrace module.
         graph_module = super().call(graph_module).graph_module

         return PassResult(graph_module, True)

     @staticmethod
     def can_fuse(
-        target_op: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
+        input_node: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
     ) -> bool:
         """
-        Determine whether a batchnorm node can be fused with a preceding conv or linear node.
+        Determine whether a BatchNorm node can be fused with the preceding convolution or linear node.
         """

-        # All the users of batchnorm node must be getitem ops. batchnorm
-        # returns a 3-element tuple. Each user must only access the first
-        # element of the tuple.
+        # All users of the batch_norm node must be getitem ops.
+        # batch_norm returns a 3-element tuple.
+        # Each user must only access the first element of the tuple.
         if [
             (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
         ].count(False):
             return False

-        target_op_weights = target_op.args[1]
+        input_node_weights = input_node.args[1]
         bn_weights = bn.args[1]

-        # Check that the weights for conv or linear and batchnorm are both params.
-        if not isinstance(target_op_weights, torch.fx.Node) or not isinstance(
+        # Check that the weights for conv or linear and batch_norm are both params.
+        if not isinstance(input_node_weights, torch.fx.Node) or not isinstance(
             bn_weights, torch.fx.Node
         ):
             return False

         if [
-            is_param_node(program, node) for node in {target_op_weights, bn_weights}
+            is_param_node(program, node) for node in {input_node_weights, bn_weights}
         ].count(False):
             return False

@@ -122,32 +116,45 @@ def _fuse_ops(
         self,
         graph_module: torch.fx.GraphModule,
         graph: torch.fx.Graph,
-        target_op: torch.fx.Node,
+        input_node: torch.fx.Node,
         bn: torch.fx.Node,
         is_conv: bool,
         constant_placeholders_to_delete: set,
     ) -> None:
         """
-        Fuse a BatchNorm into the preceding conv or linear op.
-        Update the fused op's weight and bias, rewire users of the BatchNorm's output, and remove the BatchNorm node.
+        Fuse a BatchNorm node into the preceding convolution or linear node.
+        Update the fused node's weight and bias, rewire users of the BatchNorm output,
+        and remove the BatchNorm node.
         """

         if is_conv:
-            assert len(target_op.args) == 9
-        else:  # Linear path: (input, weight, bias).
-            assert len(target_op.args) == 3
+            assert len(input_node.args) == 9
+            has_bias_arg = True
+        else:
+            # Otherwise, this is a linear node.
+            # Linear has 2 or 3 args depending on whether bias is used: (input, weight, bias).
+            assert len(input_node.args) in (2, 3)
+            has_bias_arg = len(input_node.args) == 3

         # Get the weight and bias parameters from the conv or linear op.
-        target_op_weight = get_param_tensor(self.exported_program, target_op.args[1])
-        target_op_weight_name = get_tensor_name(
-            self.exported_program, target_op.args[1]
+        input_node_weight = get_param_tensor(self.exported_program, input_node.args[1])
+        input_node_weight_name = get_tensor_name(
+            self.exported_program, input_node.args[1]
         )
-        assert target_op_weight is not None
+        assert input_node_weight is not None

-        target_op_bias = get_param_tensor(self.exported_program, target_op.args[2])
-        target_op_bias_name = get_tensor_name(self.exported_program, target_op.args[2])
+        if has_bias_arg:
+            input_node_bias = get_param_tensor(
+                self.exported_program, input_node.args[2]
+            )
+            input_node_bias_name = get_tensor_name(
+                self.exported_program, input_node.args[2]
+            )
+        else:
+            input_node_bias = None
+            input_node_bias_name = ""

-        # Get the parameters from the batchnorm op.
+        # Get the parameters from the batch_norm op.
         assert (
             bn.target == exir_ops.edge.aten.native_batch_norm.default
             and len(bn.args) == 8
@@ -169,10 +176,10 @@ def _fuse_ops(
         # as an arg).
         eps = bn.args[-1]

-        # Compute the updated weight and bias after fusing conv or linear op with batchnorm op.
+        # Compute the updated weight and bias after fusing the conv or linear op with the batch_norm op.
         fuse_args = (
-            target_op_weight,
-            target_op_bias,
+            input_node_weight,
+            input_node_bias,
             running_mean,
             running_var,
             eps,
@@ -181,23 +188,24 @@ def _fuse_ops(
         )

         if is_conv:
-            is_transpose = target_op.args[6]
+            is_transpose = input_node.args[6]
             fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose)
-        else:  # Linear path.
+        else:
+            # Otherwise, this is a linear node.
             fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args)

-        fused_weight_name = (target_op_weight_name + "_fused_bn").replace(".", "_")
-        if target_op_bias_name == "":
-            fused_bias_name = (target_op_weight_name + "_bias_fused_bn").replace(
+        fused_weight_name = (input_node_weight_name + "_fused_bn").replace(".", "_")
+        if input_node_bias_name == "":
+            fused_bias_name = (input_node_weight_name + "_bias_fused_bn").replace(
                 ".", "_"
             )
         else:
-            fused_bias_name = (target_op_bias_name + "_fused_bn").replace(".", "_")
+            fused_bias_name = (input_node_bias_name + "_fused_bn").replace(".", "_")

-        # Modify the graph by updating the weight and bias of conv or linear op
+        # Modify the graph by updating the weight and bias of the conv or linear op
         # with the fused weight and bias params, and replacing all the users
-        # of getitem(batchnorm) with the conv or linear op.
-        with graph.inserting_before(target_op.args[1]):
+        # of getitem(batch_norm) with the conv or linear op.
+        with graph.inserting_before(input_node.args[1]):
             fused_op_weight_node = create_constant_placeholder(
                 exp_program=self.exported_program,
                 graph=graph_module.graph,
@@ -216,17 +224,24 @@ def _fuse_ops(
         else:
             fused_op_bias_node = None

-        # Replace weight and bias with the fused batchnorm values.
-        args = list(target_op.args)
+        # Replace the original weight and bias with the fused batch_norm values.
+        args = list(input_node.args)
         args[1] = fused_op_weight_node
-        args[2] = fused_op_bias_node
-        target_op.args = tuple(args)

-        # Remove any use of batchnorm from the graph
+        if has_bias_arg:
+            # Overwrite original bias with the fused bias.
+            args[2] = fused_op_bias_node
+        elif fused_op_bias_node is not None:
+            # Add the fused bias as a new argument if no bias had originally existed in the input_node.
+            args.append(fused_op_bias_node)
+
+        input_node.args = tuple(args)
+
+        # Remove any use of batch_norm from the graph.
         for user in bn.users.copy():
             assert user.target == operator.getitem
-            user.replace_all_uses_with(target_op)
+            user.replace_all_uses_with(input_node)
             graph.erase_node(user)

         graph.erase_node(bn)
-        constant_placeholders_to_delete.update(target_op.args[1:3] + bn.args[1:5])
+        constant_placeholders_to_delete.update(input_node.args[1:3] + bn.args[1:5])
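
The key behavioral change above is the has_bias_arg handling: a linear node carrying only (input, weight) args now gains the fused bias as a new argument instead of overwriting a bias that does not exist. The underlying identity can be sanity-checked in eager mode; the snippet below is illustrative only, using the torch.nn.utils.fusion helper directly rather than the XNNPACK pass, on an assumed toy model:

import torch
from torch.nn.utils.fusion import fuse_linear_bn_weights

# The pattern this commit makes fusable: a bias-free Linear followed by BatchNorm1d.
linear = torch.nn.Linear(2, 2, bias=False)
bn = torch.nn.BatchNorm1d(2)
model = torch.nn.Sequential(linear, bn)

model.train()
model(torch.randn(8, 2) * 2 + 2)  # populate the BN running stats
model.eval()

# linear.bias is None here; the helper substitutes a zero bias.
fused_w, fused_b = fuse_linear_bn_weights(
    linear.weight, linear.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

fused = torch.nn.Linear(2, 2)
with torch.no_grad():
    fused.weight.copy_(fused_w)
    fused.bias.copy_(fused_b)

x = torch.randn(4, 2)
assert torch.allclose(model(x), fused(x), atol=1e-5)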

backends/xnnpack/test/passes/test_batch_norm_fusion.py

Lines changed: 12 additions & 11 deletions
@@ -41,10 +41,10 @@ def forward(self, x):
             return self.bn(y)

     class ModelLinearBN(torch.nn.Module):
-        def __init__(self, in_features, out_features):
+        def __init__(self, in_features, out_features, bias=True):
             super().__init__()
             op = torch.nn.Linear
-            self.linear = op(in_features, out_features)
+            self.linear = op(in_features, out_features, bias=bias)
             self.bn = torch.nn.BatchNorm1d(out_features)
             self.forward(torch.randn(2, 2) * 2 + 2)  # update the BN stats

@@ -109,16 +109,17 @@ def forward(self, x):
         )

     def test_fp32_linear_batch_norm_fusion(self):
-        (
-            Tester(
-                self.ModelLinearBN(2, 2).eval(),
-                (torch.randn(2, 2),),
+        for bias in [True, False]:
+            (
+                Tester(
+                    self.ModelLinearBN(2, 2, bias).eval(),
+                    (torch.randn(2, 2),),
+                )
+                .export()
+                .to_edge_transform_and_lower()
+                .check_count({self.bn_name: 1})
+                .run_method_and_compare_outputs()
             )
-            .export()
-            .to_edge_transform_and_lower()
-            .check_count({self.bn_name: 1})
-            .run_method_and_compare_outputs()
-        )

     def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self):
         """
