Commit a5d7e5c
[ET-VK] Add Fusing for Conv/Binary Ops, Clamp/Binary Ops, and Clamp/Clamp (pytorch#14415)

To improve performance, this change adds support for fusing the following op pairs:

- conv2d PW s1p0 (pointwise, stride 1, padding 0) and binary ops (add, sub, mul, div)
- clamp and binary ops (add, sub, mul, div)
- clamp and clamp

cc @SS-JIA @manuelcandales @digantdesai @cbilgin

1 parent 84d060a commit a5d7e5c

File tree

11 files changed: +1190 -19 lines
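For orientation, here is a minimal sketch of how the new passes might be applied to a lowered graph module. The toy model, import paths, and pipeline wiring are illustrative assumptions, not part of this commit; only the pass class names and source files come from the diffs below.

```python
# Hypothetical usage sketch: the toy model and the wiring are assumptions
# for illustration; only the pass classes come from this commit.
import torch

from executorch.backends.transforms.fuse_clamp_with_binary_op import (
    FuseClampBinaryOpPass,
)
from executorch.backends.transforms.fuse_clamps import FuseClampsPass


class ToyModel(torch.nn.Module):
    def forward(self, x, y):
        z = x + y                       # lowers to aten.add.Tensor
        z = torch.relu(z)               # a fuseable clamp-style op
        return torch.clamp(z, max=6.0)  # lowers to aten.clamp.default


# After lowering ToyModel to the edge dialect (not shown), the passes run
# directly on its graph module, e.g.:
#   gm = edge_program.exported_program().graph_module
#   gm = FuseClampsPass()(gm).graph_module         # relu + clamp -> one clamp
#   gm = FuseClampBinaryOpPass()(gm).graph_module  # add + clamp -> et_vk op
```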
backends/transforms/fuse_clamp_with_binary_op.py

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseClampBinaryOpPass(ExportPass):

    FUSEABLE_CLAMP_OPS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.clamp.default,
    ]
    FUSEABLE_BINARY_OPS = [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    ]

    def exists_before(self, graph_module, node_a, node_b):
        seen_a = False
        for n in graph_module.graph.nodes:
            if n is node_a:
                seen_a = True
            if n is node_b:
                return seen_a
        return False

    def get_output_min_max_from_activation(self, activation_node):
        if activation_node.target == exir_ops.edge.aten.relu.default:
            output_min = 0.0
            output_max = sys.float_info.max
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            output_min = -1.0
            output_max = 1.0
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
                output_max = activation_node.args[2]
        elif activation_node.target == exir_ops.edge.aten.clamp.default:
            output_min = None
            output_max = None
            if len(activation_node.args) >= 2:
                output_min = activation_node.args[1]
            if len(activation_node.args) >= 3:
                output_max = activation_node.args[2]

        return output_min, output_max

    def fuse_binary_op_with_clamp(self, graph_module: torch.fx.GraphModule):
        fuseAdded = False
        for clamp_node in graph_module.graph.nodes:
            if clamp_node.op == "call_function":
                if clamp_node.target in self.FUSEABLE_CLAMP_OPS:
                    preceding_op = clamp_node.args[0]

                    if (
                        preceding_op.op == "call_function"
                        and preceding_op.target in self.FUSEABLE_BINARY_OPS
                    ):
                        # Delete activation
                        output_min_max = self.get_output_min_max_from_activation(
                            clamp_node
                        )
                        new_args = list(preceding_op.args)
                        new_args.append(output_min_max[0])
                        new_args.append(output_min_max[1])
                        new_args = tuple(new_args)
                        clamp_node.replace_all_uses_with(preceding_op)
                        graph_module.graph.erase_node(clamp_node)

                        new_op = None
                        match preceding_op.target:
                            case exir_ops.edge.aten.add.Tensor:
                                new_op = (
                                    exir_ops.edge.et_vk.binary_add_with_clamp.default
                                )
                            case exir_ops.edge.aten.sub.Tensor:
                                new_op = (
                                    exir_ops.edge.et_vk.binary_sub_with_clamp.default
                                )
                            case exir_ops.edge.aten.mul.Tensor:
                                new_op = (
                                    exir_ops.edge.et_vk.binary_mul_with_clamp.default
                                )
                            case exir_ops.edge.aten.div.Tensor:
                                new_op = (
                                    exir_ops.edge.et_vk.binary_div_with_clamp.default
                                )

                        # Create and insert node of custom op `binary_<op>_with_clamp`
                        with graph_module.graph.inserting_before(preceding_op):
                            binary_op_clamp_node = graph_module.graph.create_node(
                                "call_function",
                                new_op,
                                new_args,
                            )

                        preceding_op.replace_all_uses_with(binary_op_clamp_node)
                        graph_module.graph.erase_node(preceding_op)

                        fuseAdded = True

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return [fuseAdded, graph_module]

    def call(self, graph_module: torch.fx.GraphModule):
        fuseAdded = True
        while fuseAdded:
            fuseAdded, graph_module = self.fuse_binary_op_with_clamp(graph_module)

        return PassResult(graph_module, True)
```
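Concretely, a before/after sketch of the rewrite this pass performs (node names hypothetical):

```python
# Before (edge dialect, node names hypothetical):
#   add_node   = aten.add.Tensor(x, y)
#   clamp_node = aten.clamp.default(add_node, 0.0, 6.0)
#
# After FuseClampBinaryOpPass:
#   fused = et_vk.binary_add_with_clamp.default(x, y, 0.0, 6.0)
#
# The clamp's (min, max) are appended to the binary op's args, the clamp node
# is erased, and the binary op is swapped for the et_vk custom op that
# performs the binary operation and the clamp together.
```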

backends/transforms/fuse_clamps.py

Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseClampsPass(ExportPass):

    FUSEABLE_CLAMPS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.clamp.default,
    ]

    def get_output_min_max_from_activation(self, activation_node):
        if activation_node.target == exir_ops.edge.aten.relu.default:
            output_min = 0.0
            output_max = sys.float_info.max
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            output_min = -1.0
            output_max = 1.0
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
                output_max = activation_node.args[2]
        elif activation_node.target == exir_ops.edge.aten.clamp.default:
            output_min = None
            output_max = None
            if len(activation_node.args) >= 2:
                output_min = activation_node.args[1]
            if len(activation_node.args) >= 3:
                output_max = activation_node.args[2]

        return output_min, output_max

    def call(self, graph_module: torch.fx.GraphModule):
        fuseAdded = True
        while fuseAdded:
            fuseAdded = False
            for clamp_2_node in graph_module.graph.nodes:
                if clamp_2_node.op == "call_function":
                    if clamp_2_node.target in self.FUSEABLE_CLAMPS:
                        preceding_op = clamp_2_node.args[0]
                        if (
                            preceding_op.op == "call_function"
                            and preceding_op.target in self.FUSEABLE_CLAMPS
                        ):
                            # Ensure the shapes match
                            if (
                                "val" not in clamp_2_node.args[0].meta
                                or "val" not in preceding_op.args[0].meta
                            ):
                                continue
                            if len(clamp_2_node.args[0].meta["val"].shape) != len(
                                preceding_op.args[0].meta["val"].shape
                            ):
                                continue

                            min_max1 = self.get_output_min_max_from_activation(
                                preceding_op
                            )
                            min_max2 = self.get_output_min_max_from_activation(
                                clamp_2_node
                            )

                            min_max = [None, None]

                            if min_max1[0] is None and min_max2[0] is not None:
                                min_max[0] = min_max2[0]
                            elif min_max1[0] is not None and min_max2[0] is None:
                                min_max[0] = min_max1[0]
                            else:
                                min_max[0] = min(min_max1[0], min_max2[0])

                            if min_max1[1] is None and min_max2[1] is not None:
                                min_max[1] = min_max2[1]
                            elif min_max1[1] is not None and min_max2[1] is None:
                                min_max[1] = min_max1[1]
                            else:
                                min_max[1] = max(min_max1[1], min_max2[1])

                            new_args = list(preceding_op.args)

                            # Insert the new min/max at indices 1 and 2
                            new_args.insert(1, min_max[0])
                            new_args.insert(2, min_max[1])
                            new_args = new_args[0:3]
                            preceding_op.args = tuple(new_args)
                            clamp_2_node.replace_all_uses_with(preceding_op)
                            graph_module.graph.erase_node(clamp_2_node)
                            fuseAdded = True

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
```
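To trace the None-handling branches above, a worked example with hypothetical adjacent clamps:

```python
# a = aten.clamp.default(x, 2.0)        # helper returns (2.0, None)
# b = aten.clamp.default(a, None, 5.0)  # helper returns (None, 5.0)
#
# min_max1 = (2.0, None) and min_max2 = (None, 5.0): each side has exactly
# one defined bound, so the merge keeps it, and the fused node becomes
#   aten.clamp.default(x, 2.0, 5.0)
```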

backends/transforms/fuse_conv_with_clamp.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -14,7 +14,7 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-class FuseClampPass(ExportPass):
+class FuseConvClampPass(ExportPass):
     """
     Some activations like ReLU and hardtanh can be fused with certain operators (e.g. convolution) preceding it.
     """
@@ -25,6 +25,7 @@ class FuseClampPass(ExportPass):
     FUSEABLE_ACTIVATIONS = [
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten.hardtanh.default,
+        exir_ops.edge.aten.clamp.default,
     ]
 
     def get_output_min_max_from_activation(self, activation_node):
@@ -37,6 +38,13 @@ def get_output_min_max_from_activation(self, activation_node):
         if len(activation_node.args) > 1:
             output_min = activation_node.args[1]
             output_max = activation_node.args[2]
+        elif activation_node.target == exir_ops.edge.aten.clamp.default:
+            output_min = None
+            output_max = None
+            if len(activation_node.args) >= 2:
+                output_min = activation_node.args[1]
+            if len(activation_node.args) >= 3:
+                output_max = activation_node.args[2]
 
         return output_min, output_max
```
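As a quick illustration of the new clamp branch in this helper, with hypothetical nodes:

```python
# Hypothetical inputs to get_output_min_max_from_activation and its result:
#   aten.clamp.default(x, 0.0)       -> (0.0, None)   # max argument omitted
#   aten.clamp.default(x, 0.0, 6.0)  -> (0.0, 6.0)
#   aten.relu.default(x)             -> (0.0, sys.float_info.max)
```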

backends/transforms/targets.bzl

Lines changed: 32 additions & 0 deletions
```diff
@@ -77,6 +77,38 @@ def define_common_targets():
         ],
     )
 
+    runtime.python_library(
+        name = "fuse_clamps",
+        srcs = ["fuse_clamps.py"],
+        visibility = [
+            "//executorch/backends/...",
+        ],
+        deps = [
+            ":utils",
+            "//caffe2:torch",
+            "//executorch/backends/vulkan:custom_ops_lib",
+            "//executorch/exir:pass_base",
+            "//executorch/exir:sym_util",
+            "//executorch/exir/dialects:lib",
+        ],
+    )
+
+    runtime.python_library(
+        name = "fuse_clamp_with_binary_op",
+        srcs = ["fuse_clamp_with_binary_op.py"],
+        visibility = [
+            "//executorch/backends/...",
+        ],
+        deps = [
+            ":utils",
+            "//caffe2:torch",
+            "//executorch/backends/vulkan:custom_ops_lib",
+            "//executorch/exir:pass_base",
+            "//executorch/exir:sym_util",
+            "//executorch/exir/dialects:lib",
+        ],
+    )
+
     runtime.python_library(
         name = "view_copy_to_squeeze_unsqueeze",
         srcs = ["view_copy_to_squeeze_unsqueeze.py"],
```
