Skip to content

Commit ce107df

Browse files
Quantize Addmm, Conv2d, Linear, Mm together with fusable activations
+ Move fused activations to separate QDQ cluster
1 parent ce377b7 commit ce107df

14 files changed

+947
-155
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,32 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import torch
7+
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
710
from torch.fx import GraphModule, Node
811
from torch.nn import Parameter
912

1013

14+
QUANTIZE_OPERATORS = [
15+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
16+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
17+
]
18+
19+
DEQUANTIZE_OPERATORS = [
20+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
21+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
22+
]
23+
24+
25+
def _is_dequantize(node_: Node) -> bool:
26+
return node_.op == "call_function" and node_.target in DEQUANTIZE_OPERATORS
27+
28+
29+
def _is_quantize(node_: Node) -> bool:
30+
return node_.op == "call_function" and node_.target in QUANTIZE_OPERATORS
31+
32+
1133
def input_tensor(node: Node, input_index: int) -> torch.Tensor:
1234
if len(node.all_input_nodes) <= input_index:
1335
raise IndexError
@@ -62,12 +84,6 @@ def node_is_effectively_static_tensor(
6284
if node_is_static_tensor(node, parameters_mapping):
6385
return True
6486

65-
def _is_dequantize(node_: Node) -> bool:
66-
return node_.target.__name__ in {
67-
"quantized_decomposed.dequantize_per_tensor.default",
68-
"quantized_decomposed.dequantize_per_channel.default",
69-
}
70-
7187
return _is_dequantize(node) and node_is_static_tensor(
7288
node.args[0], parameters_mapping
7389
)

backends/nxp/backend/neutron_target_spec.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,92 @@
77

88
from enum import Enum
99

10+
import torch
11+
1012
from executorch.backends.nxp.backend.neutron_converter_manager import (
1113
NeutronConverterManager,
1214
)
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
from torch.fx import Node
1318

1419

1520
class NeutronHWVersion(Enum):
1621
N1 = 1
1722
N3 = 2
1823

1924

25+
class NeutronTargetNeutronC:
    """Operator-support predicates for the Neutron-C target.

    Every public predicate is a static method that takes a single graph node and returns True
    only for `call_function` nodes whose target is one of the operators listed in the method.
    The `__aten` variants compare against `torch.ops.aten` overloads, the `__edge` variants
    against `exir_ops.edge.aten` overloads.
    """

    @staticmethod
    def _has_relu_like_hardtanh_bounds(node_: Node) -> bool:
        """Check whether the positional HardTanh bounds of `node_` are (0, 6) or (0, inf).

        Only `node_.args[1:3]` is inspected, so bounds that are omitted or are not passed
        positionally never match.

        :param node_: HardTanh node whose positional arguments are inspected.
        :return: True if the bounds correspond to Relu6 or Relu.
        """
        bounds = node_.args[1:3]
        return (
            bounds == (0.0, 6.0)  # is converted to Relu6
            or bounds == (0.0, float("inf"))  # is converted to Relu
        )

    @staticmethod
    def is_supported_fused_activation__aten(node_: Node) -> bool:
        """Node operator is a supported fused activation on Neutron for Linear and Conv2D.

        :param node_: Node in aten dialect.
        :return: True for Relu, Sigmoid, Tanh (including their in-place variants) and for
                 HardTanh (including in-place) whose bounds make it equivalent to Relu or Relu6.
        """
        if node_.op != "call_function":
            return False

        # TODO Add torch.ops.aten.leaky_relu.default once it is supported
        always_fusable = (
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
            torch.ops.aten.sigmoid.default,
            torch.ops.aten.sigmoid_.default,
            torch.ops.aten.tanh.default,
            torch.ops.aten.tanh_.default,
        )
        if node_.target in always_fusable:
            return True

        # HardTanh is fusable only when its bounds turn it into Relu / Relu6.
        hardtanh_variants = (
            torch.ops.aten.hardtanh.default,
            torch.ops.aten.hardtanh_.default,
        )
        return (
            node_.target in hardtanh_variants
            and NeutronTargetNeutronC._has_relu_like_hardtanh_bounds(node_)
        )

    @staticmethod
    def is_supported_fused_activation__edge(node_: Node) -> bool:
        """Node operator in edge dialect is a supported fused activation on Neutron for Linear and Conv2D.

        :param node_: Node in edge dialect.
        :return: True for Relu, Sigmoid, Tanh and for HardTanh whose bounds make it equivalent
                 to Relu or Relu6.
        """
        if node_.op != "call_function":
            return False

        # TODO Add exir_ops.edge.aten.leaky_relu.default once it is supported
        always_fusable = (
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.tanh.default,
        )
        if node_.target in always_fusable:
            return True

        # HardTanh is fusable only when its bounds turn it into Relu / Relu6.
        return (
            node_.target == exir_ops.edge.aten.hardtanh.default
            and NeutronTargetNeutronC._has_relu_like_hardtanh_bounds(node_)
        )

    @staticmethod
    def is_fusable_conv_or_linear__aten(node_: Node) -> bool:
        """Node operator is a supported fusable Linear or Conv2D on Neutron.

        :param node_: Node in aten dialect.
        :return: True for Conv2d, Addmm and Mm, and for Linear whose `meta["val"]` has rank 2.
        """
        if node_.op != "call_function":
            return False

        unconditionally_fusable = (
            torch.ops.aten.conv2d.default,
            torch.ops.aten.addmm.default,
            torch.ops.aten.mm.default,
        )
        if node_.target in unconditionally_fusable:
            return True

        # `meta["val"]` is read only for Linear nodes; other targets never touch the metadata.
        return (
            node_.target == torch.ops.aten.linear.default
            and len(node_.meta["val"].shape) == 2
        )

    @staticmethod
    def is_fusable_conv_or_linear__edge(node_: Node) -> bool:
        """Node operator in edge dialect is a supported fusable Linear or Conv2D on Neutron.

        :param node_: Node in edge dialect.
        :return: True for Addmm and Mm, and for Convolution whose `meta["val"]` has rank 4.
        """
        if node_.op != "call_function":
            return False

        if node_.target in (
            exir_ops.edge.aten.addmm.default,
            exir_ops.edge.aten.mm.default,
        ):
            return True

        # `meta["val"]` is read only for Convolution nodes; other targets never touch the metadata.
        return (
            node_.target == exir_ops.edge.aten.convolution.default
            and len(node_.meta["val"].shape) == 4
        )
94+
95+
2096
class NeutronTargetSpec:
2197
"""
2298
The functionality for probing the properties of Neutron Target.
@@ -39,6 +115,9 @@ def __init__(self, target: str, neutron_converter_flavor: str):
39115
f"Target `{target}` contains unsupported HW version. Only N3/N3+ targets are supported at the moment."
40116
)
41117

118+
# Now only Neutron-C is supported
119+
self.neutron_target_info = NeutronTargetNeutronC()
120+
42121
# Target name.
43122
def get_name(self) -> str:
44123
return self.neutron_target.name

backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
AddMM = exir_ops.edge.aten.addmm.default
1616
ViewCopy = exir_ops.edge.aten.view_copy.default
1717
MM = exir_ops.edge.aten.mm.default
18+
Conv = exir_ops.edge.aten.convolution.default
19+
HardTanh = exir_ops.edge.aten.hardtanh.default
20+
Relu = exir_ops.edge.aten.relu.default
21+
Sigmoid = exir_ops.edge.aten.sigmoid.default
22+
Tanh = exir_ops.edge.aten.tanh.default
1823

1924

2025
def insert_qdq_pair_after_node(
@@ -175,9 +180,23 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
175180
main_cluster_node_to_auxiliary_nodes = {
176181
AddMM: [
177182
ViewCopy,
183+
HardTanh,
184+
Relu,
185+
Sigmoid,
186+
Tanh,
178187
],
179188
MM: [
180189
ViewCopy,
190+
HardTanh,
191+
Relu,
192+
Sigmoid,
193+
Tanh,
194+
],
195+
Conv: [
196+
HardTanh,
197+
Relu,
198+
Sigmoid,
199+
Tanh,
181200
],
182201
}
183202

backends/nxp/neutron_partitioner.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
from executorch.backends.nxp.backend.custom_delegation_options import (
1616
CustomDelegationOptions,
1717
)
18+
from executorch.backends.nxp.backend.edge_helper import (
19+
DEQUANTIZE_OPERATORS,
20+
QUANTIZE_OPERATORS,
21+
)
1822
from executorch.backends.nxp.backend.edge_program_converter import (
1923
EdgeProgramToIRConverter,
2024
)
@@ -66,32 +70,26 @@ class QDQCluster:
6670
compute_node: torch.fx.Node
6771
ops: list[torch.fx.Node]
6872

69-
QUANTIZE_OPERATORS = [
70-
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
71-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
72-
]
73-
74-
DEQUANTIZE_OPERATORS = [
75-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
76-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
77-
]
78-
7973
AUXILIARY_OPS = [
8074
operator.getitem,
8175
exir_ops.edge.aten.view_copy.default,
8276
exir_ops.edge.aten.permute_copy.default,
77+
exir_ops.edge.aten.hardtanh.default,
78+
exir_ops.edge.aten.relu.default,
79+
exir_ops.edge.aten.sigmoid.default,
80+
exir_ops.edge.aten.tanh.default,
8381
]
8482

8583
def __init__(self):
8684
self.cluster_map: dict[str, QDQClusterRecognizer.QDQCluster] = {}
8785

8886
@staticmethod
8987
def is_quant_node(node: torch.fx.Node) -> bool:
90-
return node.target in QDQClusterRecognizer.QUANTIZE_OPERATORS
88+
return node.target in QUANTIZE_OPERATORS
9189

9290
@staticmethod
9391
def is_dequant_node(node: torch.fx.Node) -> bool:
94-
return node.target in QDQClusterRecognizer.DEQUANTIZE_OPERATORS
92+
return node.target in DEQUANTIZE_OPERATORS
9593

9694
@staticmethod
9795
def is_auxiliary_node(node: torch.fx.Node) -> bool:
@@ -308,18 +306,17 @@ class NeutronPartitioner(Partitioner):
308306
def __init__(
309307
self,
310308
compile_spec: list[CompileSpec],
309+
neutron_target_spec: NeutronTargetSpec,
311310
custom_delegation_options: CustomDelegationOptions | None = None,
312311
) -> None:
313312
self.delegation_spec = DelegationSpec(NeutronBackend.__name__, compile_spec)
314313
self.custom_delegation_options = (
315314
custom_delegation_options or CustomDelegationOptions()
316315
)
317-
target = self.delegation_spec[1][2].value.decode()
318-
converter_flavor = self.delegation_spec[1][3].value.decode()
319-
self.neutron_target_spec = NeutronTargetSpec(target, converter_flavor)
316+
self.neutron_target_spec = neutron_target_spec
320317

318+
@staticmethod
321319
def validate_partitioning_result(
322-
self,
323320
graph: Graph,
324321
partition_list: list[Partition],
325322
custom_delegation_options: CustomDelegationOptions,

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8-
98
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
109
NeutronAtenPassManager,
1110
)
11+
12+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
1213
from executorch.backends.nxp.quantizer.patterns import (
1314
AbsPattern,
1415
AdaptiveAvgPoolPattern,
@@ -181,27 +182,28 @@ def get_supported_operators(cls) -> list[OperatorConfig]:
181182

182183

183184
class NeutronQuantizer(ComposableQuantizer):
184-
def __init__(self):
185+
def __init__(self, neutron_target_spec: NeutronTargetSpec):
186+
self.neutron_target_spec = neutron_target_spec
185187
static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
186188
static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
187189
super().__init__(
188190
[
189191
NeutronAtenQuantizer(AbsPattern(), static_qconfig),
190192
NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
191193
NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
192-
NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig),
194+
NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
193195
NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
194196
NeutronAtenQuantizer(CatPattern(), static_qconfig),
195197
NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
196-
NeutronAtenQuantizer(Conv2dPattern(), static_qconfig),
198+
NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
197199
NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
198200
NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
199201
NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
200202
NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
201-
NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
203+
NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
202204
NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
203205
NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
204-
NeutronAtenQuantizer(MmPattern(), static_qconfig),
206+
NeutronAtenQuantizer(MmPattern(self), static_qconfig),
205207
NeutronAtenQuantizer(PadPattern(), static_qconfig),
206208
NeutronAtenQuantizer(PermutePattern(), static_qconfig),
207209
NeutronAtenQuantizer(ReluPattern(), static_qconfig),

0 commit comments

Comments
 (0)