
Commit 98a954c

Remove addmm node check from linear pass and combine conv/linear fusion test files
1 parent 2aefc75 commit 98a954c

3 files changed: +79 additions, -155 deletions

backends/xnnpack/_passes/fuse_batch_norm_with_linear.py

Lines changed: 18 additions & 62 deletions
@@ -32,12 +32,11 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
         constant_placeholders_to_delete = set()
         for linear in graph.nodes:
-            # We want to discover a chain of linear -> batch_norm or addmm -> batch_norm.
-            # Only proceed if the current node is a linear or addmm node, and has a single
+            # We want to discover a chain of linear -> batch_norm.
+            # Only proceed if the current node is a linear node, and has a single
             # user/successor.
             if (
                 linear.target != exir_ops.edge.aten.linear.default
-                and linear.target != exir_ops.edge.aten.addmm.default
                 or len(linear.users) != 1
             ):
                 continue
@@ -51,34 +50,18 @@ def call(self, graph_module: torch.fx.GraphModule):
             ):
                 continue
 
+            if not self.can_fuse(linear, bn, self.exported_program):
+                continue
+
             # Get the parameters
             assert len(linear.args) == 3
 
-            if linear.target == exir_ops.edge.aten.addmm.default:
-                # addmm.args = (bias, input, weight)
-                linear_bias_arg = linear.args[0]
-                linear_input_arg = linear.args[1]
-                # Unwrap permute_copy to access weight parameter node
-                linear_weight_arg = FuseBatchNormWithLinearPass._unwrap_node(
-                    linear.args[2]
-                )
-            else:
-                # linear.args = (input, weight, bias)
-                linear_input_arg = linear.args[0]
-                linear_weight_arg = linear.args[1]
-                linear_bias_arg = linear.args[2]
-
-            if not self.can_fuse(linear_weight_arg, bn, self.exported_program):
-                continue
-
-            linear_weight = get_param_tensor(self.exported_program, linear_weight_arg)
-            linear_weight_name = get_tensor_name(
-                self.exported_program, linear_weight_arg
-            )
+            linear_weight = get_param_tensor(self.exported_program, linear.args[1])
+            linear_weight_name = get_tensor_name(self.exported_program, linear.args[1])
             assert linear_weight is not None
 
-            linear_bias = get_param_tensor(self.exported_program, linear_bias_arg)
-            linear_bias_name = get_tensor_name(self.exported_program, linear_bias_arg)
+            linear_bias = get_param_tensor(self.exported_program, linear.args[2])
+            linear_bias_name = get_tensor_name(self.exported_program, linear.args[2])
 
             # Get the parameters from the batchnorm op
             assert (
@@ -112,12 +95,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 bn_weight,
                 bn_bias,
             )
-
-            if linear.target == exir_ops.edge.aten.addmm.default:
-                # fuse_linear_bn_weights returns weight [out, in];
-                # permute_copy node was removed, so weight must be transposed to [in, out] for addmm
-                fused_weight = fused_weight.t()
-
             fused_weight_name = (linear_weight_name + "_fused_bn").replace(".", "_")
             if linear_bias_name == "":
                 fused_bias_name = (linear_weight_name + "_bias_fused_bn").replace(
@@ -130,7 +107,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             # with the fused weight and bias params, and replacing all the users
             # of getitem(batchnorm) with the linear op.
 
-            with graph.inserting_before(linear_weight_arg):
+            with graph.inserting_before(linear.args[1]):
                 fused_linear_weight_node = create_constant_placeholder(
                     exp_program=self.exported_program,
                     graph=graph_module.graph,
@@ -149,20 +126,11 @@ def call(self, graph_module: torch.fx.GraphModule):
             else:
                 fused_linear_bias_node = None
 
-            if linear.target == exir_ops.edge.aten.addmm.default:
-                # addmm.args = (bias, input, weight)
-                linear.args = (
-                    fused_linear_bias_node,
-                    linear_input_arg,
-                    fused_linear_weight_node,
-                )
-            else:
-                # linear.args = (input, weight, bias)
-                linear.args = (
-                    linear_input_arg,
-                    fused_linear_weight_node,
-                    fused_linear_bias_node,
-                )
+            linear.args = (
+                linear.args[0],
+                fused_linear_weight_node,
+                fused_linear_bias_node,
+            )
 
             # Remove any use of batchnorm from the graph
             for user in bn.users.copy():
@@ -187,7 +155,7 @@ def call(self, graph_module: torch.fx.GraphModule):
 
     @staticmethod
     def can_fuse(
-        linear_weights: torch.fx.Node,
+        linear: torch.fx.Node,
         bn: torch.fx.Node,
         program: ExportedProgram,
     ) -> bool:
@@ -206,23 +174,11 @@ def can_fuse(
         bn_weights = bn.args[1]
 
         # Check that the weights for linear and batchnorm are both params
-        if not isinstance(linear_weights, torch.fx.Node) or not isinstance(
+        if not isinstance(linear, torch.fx.Node) or not isinstance(
             bn_weights, torch.fx.Node
         ):
             return False
 
-        if [
-            is_param_node(program, node) for node in {linear_weights, bn_weights}
-        ].count(False):
+        if [is_param_node(program, node) for node in {linear, bn_weights}].count(False):
            return False
-
        return True
-
-    @staticmethod
-    def _unwrap_node(node: torch.fx.Node) -> torch.fx.Node:
-        while node.op == "call_function" and node.target in {
-            exir_ops.edge.aten.permute.default,
-            exir_ops.edge.aten.permute_copy.default,
-        }:
-            node = node.args[0]
-        return node
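
For context on the arithmetic this pass relies on: the fused weight and bias come from standard batch-norm folding, and the removed addmm branch only needed a transpose because fuse_linear_bn_weights returns the weight in [out, in] layout while addmm expects [in, out]. Below is a minimal, self-contained sketch of that folding math; the helper name fold_bn_into_linear is illustrative and not part of the pass.

import torch

def fold_bn_into_linear(w, b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # Standard folding: eval-mode bn(Wx + b) = (scale * W) x + scale * (b - rm) + bn_b,
    # where scale = bn_w / sqrt(rv + eps). The weight stays in [out, in] layout.
    scale = bn_w / torch.sqrt(bn_rv + bn_eps)
    return w * scale.unsqueeze(1), (b - bn_rm) * scale + bn_b

lin = torch.nn.Linear(2, 2)
bn = torch.nn.BatchNorm1d(2)
bn(torch.randn(8, 2) * 3 + 1)  # one training-mode pass so the running stats are non-trivial
bn.eval()

x = torch.randn(4, 2)
fused_w, fused_b = fold_bn_into_linear(
    lin.weight, lin.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)
# The fused linear reproduces the linear -> batch_norm chain in eval mode.
assert torch.allclose(
    bn(lin(x)), torch.nn.functional.linear(x, fused_w, fused_b), atol=1e-6
)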

backends/xnnpack/test/passes/test_batch_norm_fusion.py

Lines changed: 61 additions & 6 deletions
@@ -11,11 +11,15 @@
 from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
     FuseBatchNormWithConvPass,
 )
+from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import (
+    FuseBatchNormWithLinearPass,
+)
 from executorch.backends.xnnpack.test.tester import RunPasses, Tester
 
 
 class TestBatchNormFusion(unittest.TestCase):
-    PassStage = RunPasses([FuseBatchNormWithConvPass])
+    ConvPassStage = RunPasses([FuseBatchNormWithConvPass])
+    LinearPassStage = RunPasses([FuseBatchNormWithLinearPass])
     bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"
 
     def setUp(self):
@@ -42,7 +46,22 @@ def forward(self, x):
             y = y + y
             return self.bn(y)
 
-    def test_fp32_batch_norm_fusion(self):
+    class ModelLinearBN(torch.nn.Module):
+        def __init__(self, in_features, out_features):
+            super().__init__()
+            op = torch.nn.Linear
+            self.linear = op(in_features, out_features)
+            self.bn = torch.nn.BatchNorm1d(out_features)
+            self.forward(torch.randn(2, 2) * 2 + 2)  # update the BN stats
+
+        def forward(self, x):
+            y = self.linear(x)
+            y = self.bn(y)
+            y = self.linear(y)
+            y = y + y
+            return self.bn(y)
+
+    def test_fp32_conv_batch_norm_fusion(self):
         for transpose in [False, True]:
             (
                 Tester(
@@ -51,12 +70,12 @@ def test_fp32_batch_norm_fusion(self):
                 )
                 .export()
                 .to_edge()
-                .run_passes(self.PassStage)
+                .run_passes(self.ConvPassStage)
                 .check_count({self.bn_name: 1})
                 .run_method_and_compare_outputs()
             )
 
-    def test_q8_batch_norm_fusion(self):
+    def test_q8_conv_batch_norm_fusion(self):
         for transpose in [False, True]:
             (
                 Tester(
@@ -66,12 +85,12 @@ def test_q8_batch_norm_fusion(self):
                 .quantize()
                 .export()
                 .to_edge()
-                .run_passes(self.PassStage)
+                .run_passes(self.ConvPassStage)
                 .check_count({self.bn_name: 1})
                 .run_method_and_compare_outputs()
             )
 
-    def test_fp32_batch_norm_no_fusion_doesnt_partition(self):
+    def test_fp32_conv_batch_norm_no_fusion_doesnt_partition(self):
         """
         We do not currently support standalone batch norms (i.e. batch norms that are
         not fused with a conv). This is planned, but until implemented, this test ensures
@@ -94,3 +113,39 @@ def forward(self, x):
             .partition()
             .check_count({self.bn_name: 1})
         )
+
+    def test_fp32_linear_batch_norm_fusion(self):
+        (
+            Tester(
+                self.ModelLinearBN(2, 2).eval(),
+                (torch.randn(2, 2),),
+            )
+            .export()
+            .to_edge_transform_and_lower()
+            .check_count({self.bn_name: 1})
+            .run_method_and_compare_outputs()
+        )
+
+    # def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self):
+    #     """
+    #     We do not currently support standalone batch norms (i.e. batch norms that are
+    #     not fused with a linear). This is planned, but until implemented, this test ensures
+    #     that we do not partition the standalone batch norm and then fail to lower.
+    #     """
+    #
+    #     class BN(torch.nn.Module):
+    #         def __init__(self):
+    #             super().__init__()
+    #             self.bn = torch.nn.BatchNorm1d(2)
+    #
+    #         def forward(self, x):
+    #             return self.bn(x)
+    #
+    #     (
+    #         Tester(BN(), (torch.randn(2, 2),))
+    #         .export()
+    #         .to_edge()
+    #         .check_count({self.bn_name: 1})
+    #         .partition()
+    #         .check_count({self.bn_name: 1})
+    #     )
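
One detail worth noting in the new ModelLinearBN fixture is the self.forward(torch.randn(2, 2) * 2 + 2) call in __init__. With freshly initialized running statistics (mean 0, variance 1), an eval-mode BatchNorm1d is nearly the identity, so presumably the warm-up pass exists to make the output comparison actually sensitive to the fusion. A standalone sketch of that effect (not part of the test file):

import torch

bn = torch.nn.BatchNorm1d(2)
x = torch.randn(2, 2) * 2 + 2

bn.eval()
print(torch.allclose(bn(x), x, atol=1e-3))  # True: default running stats make eval BN ~identity

bn.train()
bn(x)          # one training-mode pass updates running_mean / running_var
bn.eval()
print(torch.allclose(bn(x), x, atol=1e-3))  # False: BN now visibly transforms its input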

backends/xnnpack/test/passes/test_batch_norm_linear_fusion.py

Lines changed: 0 additions & 87 deletions
This file was deleted.
