Commit 9bd8064

Merge conv and linear fusion passes into FuseBatchNormPass
1 parent 8003ba3 commit 9bd8064

File tree

6 files changed (+240 / -411 lines)


backends/xnnpack/_passes/__init__.py

Lines changed: 2 additions & 8 deletions
@@ -21,12 +21,7 @@
 )
 from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate
 from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass
-from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
-    FuseBatchNormWithConvPass,
-)
-from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import (
-    FuseBatchNormWithLinearPass,
-)
+from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
     TagImplicitQDqPass,
@@ -66,8 +61,7 @@ def __init__(
             ConvertToLinearPass,
             ConvertToSDPAPass,
             ConstPropPass,
-            FuseBatchNormWithConvPass,
-            FuseBatchNormWithLinearPass,
+            FuseBatchNormPass,
             FuseActivationPass,
             DecomposeConcatenate,
             RemoveGetItemPass,
backends/xnnpack/_passes/fuse_batch_norm.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch
from executorch.backends.transforms.utils import (
    create_constant_placeholder,
    delete_constant_placeholder,
)

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

from executorch.backends.xnnpack.utils.utils import (
    get_param_tensor,
    get_tensor_name,
    is_param_node,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind

from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights


class FuseBatchNormPass(XNNPACKPass):
    """
    BatchNorm can be implemented using 1x1 Depthwise Convolution. However, doing so will increase
    memory usage since we serialize new weights to represent the convolution. In most cases,
    BatchNorm is used after convolution or linear. The 1x1 depthwise convolution can then be fused
    with the previous convolution. For linear cases, BatchNorm can be folded into the previous linear layer.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        constant_placeholders_to_delete = set()
        for node in graph.nodes:
            # We want to discover a chain of conv -> batch_norm or linear -> batch_norm.
            # Only proceed if the current node is a conv or linear node, and has a single
            # user/successor.
            is_conv = node.target == exir_ops.edge.aten.convolution.default
            is_linear = node.target == exir_ops.edge.aten.linear.default

            if not (is_conv or is_linear):
                continue
            if len(node.users) != 1:
                continue

            # Conv or linear op to fuse.
            target_op = node

            # The single user of the op must be batch_norm. If not, bail.
            bn = list(target_op.users.keys())[0]
            if (
                bn.target != exir_ops.edge.aten.native_batch_norm.default
                and bn.target
                != exir_ops.edge.aten._native_batch_norm_legit_no_training.default
            ):
                continue

            if not self.can_fuse(target_op, bn, self.exported_program):
                continue

            self._fuse_ops(
                graph_module,
                graph,
                target_op,
                bn,
                is_conv,
                constant_placeholders_to_delete,
            )

        if len(constant_placeholders_to_delete) > 0:
            graph_module.graph.eliminate_dead_code()
            for node in constant_placeholders_to_delete:
                if (node is not None) and (len(node.users) == 0):
                    delete_constant_placeholder(self.exported_program, node)

        graph_module.recompile()
        # To regenerate metadata and shape information, retrace the module.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

    @staticmethod
    def can_fuse(
        target_op: torch.fx.Node, bn: torch.fx.Node, program: ExportedProgram
    ) -> bool:
        """
        Determine whether a batchnorm node can be fused with a preceding conv or linear node.
        """

        # All the users of the batchnorm node must be getitem ops. batchnorm
        # returns a 3-element tuple. Each user must only access the first
        # element of the tuple.
        if [
            (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
        ].count(False):
            return False

        target_op_weights = target_op.args[1]
        bn_weights = bn.args[1]

        # Check that the weights for conv or linear and batchnorm are both params.
        if not isinstance(target_op_weights, torch.fx.Node) or not isinstance(
            bn_weights, torch.fx.Node
        ):
            return False

        if [
            is_param_node(program, node) for node in {target_op_weights, bn_weights}
        ].count(False):
            return False

        return True

    def _fuse_ops(
        self,
        graph_module: torch.fx.GraphModule,
        graph: torch.fx.Graph,
        target_op: torch.fx.Node,
        bn: torch.fx.Node,
        is_conv: bool,
        constant_placeholders_to_delete: set,
    ) -> None:
        """
        Fuse a BatchNorm into the preceding conv or linear op.
        Update the fused op's weight and bias, rewire users of the BatchNorm's output, and remove the BatchNorm node.
        """

        if is_conv:
            assert len(target_op.args) == 9
        else:  # Linear path: (input, weight, bias).
            assert len(target_op.args) == 3

        # Get the weight and bias parameters from the conv or linear op.
        target_op_weight = get_param_tensor(self.exported_program, target_op.args[1])
        target_op_weight_name = get_tensor_name(
            self.exported_program, target_op.args[1]
        )
        assert target_op_weight is not None

        target_op_bias = get_param_tensor(self.exported_program, target_op.args[2])
        target_op_bias_name = get_tensor_name(self.exported_program, target_op.args[2])

        # Get the parameters from the batchnorm op.
        assert (
            bn.target == exir_ops.edge.aten.native_batch_norm.default
            and len(bn.args) == 8
        ) or (
            bn.target == exir_ops.edge.aten._native_batch_norm_legit_no_training.default
            and len(bn.args) == 7
        )
        bn_weight = get_param_tensor(self.exported_program, bn.args[1])
        bn_bias = get_param_tensor(self.exported_program, bn.args[2])

        running_mean = get_param_tensor(self.exported_program, bn.args[3])
        assert running_mean is not None

        running_var = get_param_tensor(self.exported_program, bn.args[4])
        assert running_var is not None

        # args[7] for native_batch_norm, but args[6] for
        # _native_batch_norm_legit_no_training (which doesn't have training
        # as an arg).
        eps = bn.args[-1]

        # Compute the updated weight and bias after fusing the conv or linear op with the batchnorm op.
        fuse_args = (
            target_op_weight,
            target_op_bias,
            running_mean,
            running_var,
            eps,
            bn_weight,
            bn_bias,
        )

        if is_conv:
            is_transpose = target_op.args[6]
            fused_weight, fused_bias = fuse_conv_bn_weights(*fuse_args, is_transpose)
        else:  # Linear path.
            fused_weight, fused_bias = fuse_linear_bn_weights(*fuse_args)

        fused_weight_name = (target_op_weight_name + "_fused_bn").replace(".", "_")
        if target_op_bias_name == "":
            fused_bias_name = (target_op_weight_name + "_bias_fused_bn").replace(
                ".", "_"
            )
        else:
            fused_bias_name = (target_op_bias_name + "_fused_bn").replace(".", "_")

        # Modify the graph by updating the weight and bias of the conv or linear op
        # with the fused weight and bias params, and replacing all the users
        # of getitem(batchnorm) with the conv or linear op.
        with graph.inserting_before(target_op.args[1]):
            fused_op_weight_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=graph_module.graph,
                kind=InputKind.PARAMETER,
                name=fused_weight_name,
                data=fused_weight,
            )
            if fused_bias is not None:
                fused_op_bias_node = create_constant_placeholder(
                    exp_program=self.exported_program,
                    graph=graph_module.graph,
                    kind=InputKind.PARAMETER,
                    name=fused_bias_name,
                    data=fused_bias,
                )
            else:
                fused_op_bias_node = None

        # Replace weight and bias with the fused batchnorm values.
        args = list(target_op.args)
        args[1] = fused_op_weight_node
        args[2] = fused_op_bias_node
        target_op.args = tuple(args)

        # Remove any use of batchnorm from the graph.
        for user in bn.users.copy():
            assert user.target == operator.getitem
            user.replace_all_uses_with(target_op)
            graph.erase_node(user)

        graph.erase_node(bn)
        constant_placeholders_to_delete.update(target_op.args[1:3] + bn.args[1:5])
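
For context only (not part of the commit): the fused weight and bias come from torch.nn.utils.fusion, which folds the BatchNorm affine transform and running statistics into the preceding op's parameters. The standalone sketch below, with arbitrary toy shapes, illustrates the equivalence the pass relies on for the linear path; the conv path uses fuse_conv_bn_weights the same way.

# Standalone sketch (toy shapes chosen for illustration): BatchNorm applied after
# a linear layer equals a single linear layer with rescaled weight and bias.
import torch
from torch.nn.utils.fusion import fuse_linear_bn_weights

linear = torch.nn.Linear(8, 4)
bn = torch.nn.BatchNorm1d(4).eval()   # eval mode: uses running statistics
bn.running_mean = torch.randn(4)      # give the running stats non-trivial values
bn.running_var = torch.rand(4) + 0.5

# Same argument order the pass builds in fuse_args:
# weight, bias, running_mean, running_var, eps, bn_weight, bn_bias.
fused_weight, fused_bias = fuse_linear_bn_weights(
    linear.weight, linear.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

x = torch.randn(2, 8)
reference = bn(linear(x))
fused = torch.nn.functional.linear(x, fused_weight, fused_bias)
assert torch.allclose(reference, fused, atol=1e-5)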
