Commit a78f023

Arm backend: Fix get_module_name_filter (#15910)
Update TOSAQuantizer to use get_module_name_filter from torchao. Adds new tests to validate that set_module_name works as intended. Fixes #15870.

cc @freddan80 @per @zingo @digantdesai

Signed-off-by: Oscar Andersson <[email protected]>
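For orientation, here is a minimal sketch of the per-module quantization flow this commit exercises, condensed from the new test below. `MyConvModel` is a hypothetical stand-in for any module with a submodule named "conv0"; the TOSA spec string and the rest mirror the test:

```python
import torch
from executorch.backends.arm.quantizer import (
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa import TosaSpecification
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Configure only the module named "conv0"; no global config is set,
# so nodes outside that module should stay unquantized.
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
quantizer.set_module_name("conv0", get_symmetric_quantization_config())

model = MyConvModel().eval()  # hypothetical module with a "conv0" submodule
inputs = (torch.randn(1, 3, 64, 64),)
exported = torch.export.export(model, inputs)

prepared = prepare_pt2e(exported.module(), quantizer)
prepared(*inputs)  # calibration run
converted = convert_pt2e(prepared)
```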
2 files changed: +161 -29 lines

backends/arm/quantizer/arm_quantizer.py

Lines changed: 3 additions & 29 deletions
```diff
@@ -46,6 +46,7 @@
 from torchao.quantization.pt2e.quantizer import (
     annotate_input_qspec_map,
     annotate_output_qspec,
+    get_module_name_filter,
     QuantizationSpec,
     Quantizer,
 )
@@ -248,33 +249,6 @@ def get_symmetric_a16w8_quantization_config(
     """
 
 
-def _get_module_name_filter(module_name: str) -> NodeFilterType:
-    """Get the module_name_filter function for a given module name, the filter accepts
-    a node and checks if the node comes from a module that has certain module name
-
-    For example:
-        node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
-
-    >> module_name_filter = _get_module_name_filter("blocks.sub")
-    >> print(module_name_filter(node))
-    True  # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
-    """
-
-    name_start = len("L['self'].")
-
-    def module_name_filter(n: Node) -> bool:
-        # node_stack example: {
-        #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
-        #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
-        # }
-        # get_attr nodes doesn't have nn_module_stack?
-        nn_module_stack = n.meta.get("nn_module_stack", {})
-        names = [name[name_start:] for name, _ in nn_module_stack.values()]
-        return module_name in names
-
-    return module_name_filter
-
-
 def _get_module_type_filter(tp: Callable) -> NodeFilterType:
     """Get the module_type_filter function for a given module type, the filter accepts
     a node and checks if the node comes from a module that has certain module type
@@ -306,7 +280,7 @@ def _get_not_module_type_or_name_filter(
     tp_list: List[Callable], module_name_list: List[str]
 ) -> NodeFilterType:
     module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
-    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+    module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]
 
     def not_module_type_or_name_filter(n: Node) -> bool:
         return not any(f(n) for f in module_type_filters + module_name_list_filters)
@@ -455,7 +429,7 @@ def _annotate_for_static_quantization_config(
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_static_patterns(
-                model, config, _get_module_name_filter(module_name)
+                model, config, get_module_name_filter(module_name)
             )
 
         tp_list = list(self.module_type_config.keys())
```
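The diff itself does not state why the local helper misbehaved. One plausible reading, which is an assumption here rather than something the commit says, is suggested by the removed comments: the helper unconditionally strips a torch.compile-style "L['self']." prefix that torch.export'ed graphs do not produce, so exported module names were mangled and set_module_name never matched. A tiny illustration:

```python
# Illustration only (assumption, not from the commit): the removed filter
# strips len("L['self'].") == 10 characters from every module-stack name.
name_start = len("L['self'].")

# torch.compile-style name, which the old helper assumed:
compiled_name = "L['self'].conv0"
print(compiled_name[name_start:])  # "conv0" -- matches set_module_name("conv0")

# torch.export-style name carries no such prefix, so stripping destroys it:
exported_name = "conv0"
print(exported_name[name_start:])  # "" -- never matches, filter silently fails
```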
New test file

Lines changed: 158 additions & 0 deletions

```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.quantizer import (
    get_symmetric_a16w8_quantization_config,
    get_symmetric_quantization_config,
    is_annotated,
    QuantizationConfig,
    TOSAQuantizer,
)
from executorch.backends.arm.quantizer.quantization_config import QuantizationSpec
from executorch.backends.arm.tosa import TosaSpecification
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

DQ_PER_CHANNEL = torch.ops.quantized_decomposed.dequantize_per_channel.default
DQ_PER_TENSOR = torch.ops.quantized_decomposed.dequantize_per_tensor.default
Q_PER_TENSOR = torch.ops.quantized_decomposed.quantize_per_tensor.default


class ConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = torch.nn.Conv2d(
            3,
            16,
            kernel_size=4,
        )
        self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=3, bias=False)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv0(x)
        x = torch.sigmoid(x)
        x = self.conv1(x)
        x = torch.tanh(x)
        x = self.conv2(x)
        return x


test_inputs = (torch.randn(1, 3, 64, 64),)


def validate_per_tensor_quant(node: torch.fx.Node, qspec: QuantizationSpec):
    _, _, zero_point, qmin, qmax, dtype = node.args
    if qspec.qscheme == torch.per_tensor_symmetric:
        assert (
            zero_point == 0
        ), f"Zero point {zero_point} is not zero for symmetric quantization"
    assert (
        qmin == qspec.quant_min
    ), f"Quant min {qmin} does not match expected {qspec.quant_min}"
    assert (
        qmax == qspec.quant_max
    ), f"Quant max {qmax} does not match expected {qspec.quant_max}"
    assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}"


def validate_per_channel_quant(node: torch.fx.Node, qspec: QuantizationSpec):
    _, _, _, channel_axis, qmin, qmax, dtype = node.args
    assert (
        channel_axis == qspec.ch_axis
    ), f"Channel axis {channel_axis} does not match expected {qspec.ch_axis}"
    assert (
        qmin == qspec.quant_min
    ), f"Quant min {qmin} does not match expected {qspec.quant_min}"
    assert (
        qmax == qspec.quant_max
    ), f"Quant max {qmax} does not match expected {qspec.quant_max}"
    assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}"


def validate_input(input_node: torch.fx.Node, qspec: QuantizationSpec | None):
    if qspec is None:
        return

    per_channel = qspec.qscheme == torch.per_channel_symmetric
    expected_dequant_op = DQ_PER_CHANNEL if per_channel else DQ_PER_TENSOR
    assert (
        input_node.target == expected_dequant_op
    ), f"Input node {input_node} is not quantized as expected"
    if per_channel:
        validate_per_channel_quant(input_node, qspec)
    else:
        validate_per_tensor_quant(input_node, qspec)


def validate_output(node: torch.fx.Node, qspec: QuantizationSpec | None):
    if qspec is None:
        return
    users = list(node.users)
    assert len(users) == 1, f"Node {node} should have exactly one user"
    assert (
        users[0].target == Q_PER_TENSOR
    ), f"Output node {users[0]} is not quantized as expected"
    validate_per_tensor_quant(users[0], qspec)


def validate_node(
    node: torch.fx.Node, quantization_config: QuantizationConfig | None
) -> None:
    if quantization_config is None:
        assert not is_annotated(node), f"Node {node} is unexpectedly annotated"
        return

    assert is_annotated(node), f"Node {node} is not annotated"
    input_qspec = quantization_config.get_input_act_qspec()
    output_qspec = quantization_config.get_output_act_qspec()
    weight_qspec = quantization_config.get_weight_qspec()

    if len(node.all_input_nodes) == 3:
        input_node, weight_node, bias_node = node.all_input_nodes
        bias_qspec = quantization_config.get_bias_qspec(node)
        validate_input(bias_node, bias_qspec)
    else:
        input_node, weight_node = node.all_input_nodes

    validate_input(input_node, input_qspec)
    validate_input(weight_node, weight_qspec)
    validate_output(node, output_qspec)


def test_set_module_name() -> None:
    model = ConvModel()
    model.eval()

    # Set up quantizer with different configs for different modules
    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    quantizer = TOSAQuantizer(tosa_spec)
    int8_config = get_symmetric_quantization_config(is_per_channel=False)
    a16w8_config = get_symmetric_a16w8_quantization_config()
    # Set module-specific configurations but don't set global config to test that
    # only specified modules are quantized
    quantizer.set_module_name("conv0", int8_config)
    quantizer.set_module_name("conv1", a16w8_config)

    # Export model
    exported_model = torch.export.export(model, test_inputs)

    # Prepare, calibrate and convert model
    prepared_model = prepare_pt2e(exported_model.module(), quantizer)
    prepared_model(*test_inputs)
    converted_model = convert_pt2e(prepared_model)

    validate_node(
        [node for node in converted_model.graph.nodes if node.name == "conv2d"][0],
        int8_config,
    )
    validate_node(
        [node for node in converted_model.graph.nodes if node.name == "conv2d_1"][0],
        a16w8_config,
    )
    validate_node(
        [node for node in converted_model.graph.nodes if node.name == "conv2d_2"][0],
        None,
    )
```
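One detail worth noting in the assertions above: the lookups by node name rely on torch.export's convention of naming call_function nodes after their operator, with numeric suffixes for repeats ("conv2d", "conv2d_1", ...). A minimal equivalent lookup, assuming the converted_model produced in the test:

```python
# Collect the three conv nodes the test asserts on, in graph order.
conv_nodes = [
    node
    for node in converted_model.graph.nodes
    if node.op == "call_function" and node.name.startswith("conv2d")
]
print([node.name for node in conv_nodes])  # ['conv2d', 'conv2d_1', 'conv2d_2']
```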
