
Commit 8003ba3

Register Linear fusion pass in BatchNormConfig
1 parent 98a954c

3 files changed: 40 additions, 34 deletions


backends/xnnpack/_passes/fuse_batch_norm_with_linear.py

Lines changed: 7 additions & 4 deletions
@@ -41,7 +41,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             ):
                 continue

-            # Single user of the linear op must be batch_norm
+            # Single user of the linear op must be batch_norm. If not, bail.
             bn = list(linear.users.keys())[0]
             if (
                 bn.target != exir_ops.edge.aten.native_batch_norm.default
@@ -53,7 +53,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             if not self.can_fuse(linear, bn, self.exported_program):
                 continue

-            # Get the parameters
+            # Get the parameters from linear op
             assert len(linear.args) == 3

             linear_weight = get_param_tensor(self.exported_program, linear.args[1])
@@ -171,14 +171,17 @@ def can_fuse(
         ].count(False):
             return False

+        linear_weights = linear.args[1]
         bn_weights = bn.args[1]

         # Check that the weights for linear and batchnorm are both params
-        if not isinstance(linear, torch.fx.Node) or not isinstance(
+        if not isinstance(linear_weights, torch.fx.Node) or not isinstance(
             bn_weights, torch.fx.Node
         ):
             return False

-        if [is_param_node(program, node) for node in {linear, bn_weights}].count(False):
+        if [
+            is_param_node(program, node) for node in {linear_weights, bn_weights}
+        ].count(False):
             return False
         return True
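For reference, what this pass ultimately computes is a fold of the batch norm's affine transform into the preceding linear's weight and bias. Below is a minimal numerical sketch of that algebra in plain PyTorch, not the ExecuTorch pass itself; the helper name fold_bn_into_linear is illustrative only.

import torch


def fold_bn_into_linear(linear: torch.nn.Linear, bn: torch.nn.BatchNorm1d) -> torch.nn.Linear:
    # BN(Wx + b) = W'x + b' with scale = gamma / sqrt(running_var + eps),
    # W' = diag(scale) @ W and b' = scale * (b - running_mean) + beta.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = torch.nn.Linear(linear.in_features, linear.out_features, bias=True)
    with torch.no_grad():
        fused.weight.copy_(linear.weight * scale.unsqueeze(1))
        bias = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
        fused.bias.copy_(scale * (bias - bn.running_mean) + bn.bias)
    return fused


# Quick equivalence check: populate BN running stats, then compare in eval mode.
linear, bn = torch.nn.Linear(4, 2), torch.nn.BatchNorm1d(2)
bn.train()
bn(linear(torch.randn(32, 4)))
bn.eval()
x = torch.randn(8, 4)
assert torch.allclose(bn(linear(x)), fold_bn_into_linear(linear, bn)(x), atol=1e-5)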

backends/xnnpack/partition/config/node_configs.py

Lines changed: 12 additions & 7 deletions
@@ -12,6 +12,9 @@
 from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
     FuseBatchNormWithConvPass,
 )
+from executorch.backends.xnnpack._passes.fuse_batch_norm_with_linear import (
+    FuseBatchNormWithLinearPass,
+)
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
     XNNPartitionerConfig,
@@ -35,20 +38,22 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False

         bn = node
-        conv = node.all_input_nodes[0]
+        input_node = node.all_input_nodes[0]

-        if conv.op != "call_function":
+        if input_node.op != "call_function":
             return False

-        conv_name = format_target_name(conv.target.__name__)  # pyre-ignore
+        input_name = format_target_name(input_node.target.__name__)  # pyre-ignore

-        if conv_name not in ["convolution.default"]:
-            why(node, f"Invalid conv target {conv_name}")
+        if input_name not in ["convolution.default", "linear.default"]:
+            why(node, f"Invalid input target {input_name.split('.')[0]}")
             return False

-        can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep)
+        can_fuse = FuseBatchNormWithConvPass.can_fuse(
+            input_node, bn, ep
+        ) or FuseBatchNormWithLinearPass.can_fuse(input_node, bn, ep)
         if not can_fuse:
-            why(node, "BatchNorm cannot be fused with Convolution")
+            why(node, f"BatchNorm cannot be fused with {input_name.split('.')[0]}")
             return False

         return True
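With the linear fusion pass registered in BatchNormConfig, a linear -> batch_norm chain can be claimed by the XNNPACK partitioner instead of being rejected. A rough end-to-end sketch of exercising that path follows; it assumes the XnnpackPartitioner and to_edge_transform_and_lower import paths of recent ExecuTorch releases (they are not part of this diff and may differ by version).

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class LinearBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)
        self.bn = torch.nn.BatchNorm1d(2)

    def forward(self, x):
        return self.bn(self.linear(x))


model = LinearBN().eval()  # eval mode so batch norm uses its running stats
example_inputs = (torch.randn(3, 4),)
exported = torch.export.export(model, example_inputs)

# With the linear fusion pass registered, the batch_norm node can be partitioned
# together with its producing linear and fused away during lowering.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
print(edge.exported_program().graph_module.graph)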

backends/xnnpack/test/passes/test_batch_norm_fusion.py

Lines changed: 21 additions & 23 deletions
@@ -126,26 +126,24 @@ def test_fp32_linear_batch_norm_fusion(self):
             .run_method_and_compare_outputs()
         )

-    # def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self):
-    #     """
-    #     We do not currently support standalone batch norms (i.e. batch norms that are
-    #     not fused with a linear). This is planned, but until implemented, this test ensures
-    #     that we do not partition the standalone batch norm and then fail to lower.
-    #     """
-    #
-    #     class BN(torch.nn.Module):
-    #         def __init__(self):
-    #             super().__init__()
-    #             self.bn = torch.nn.BatchNorm1d(2)
-    #
-    #         def forward(self, x):
-    #             return self.bn(x)
-    #
-    #     (
-    #         Tester(BN(), (torch.randn(2, 2),))
-    #         .export()
-    #         .to_edge()
-    #         .check_count({self.bn_name: 1})
-    #         .partition()
-    #         .check_count({self.bn_name: 1})
-    #     )
+    def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self):
+        """
+        We do not currently support standalone batch norms (i.e. batch norms that are
+        not fused with a linear). This is planned, but until implemented, this test ensures
+        that we do not partition the standalone batch norm and then fail to lower.
+        """
+
+        class BN(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bn = torch.nn.BatchNorm1d(2)
+
+            def forward(self, x):
+                return self.bn(x)
+
+        (
+            Tester(BN(), (torch.randn(2, 2),))
+            .export()
+            .to_edge_transform_and_lower()
+            .check_count({self.bn_name: 1})
+        )
