
Commit 810b8e8

Qualcomm AI Engine Direct - Suite operator fix part 3 (#15182)
### Summary

- Supports transpose_conv with per_channel_quant. Previously, per_channel_quant for transpose_conv was not handled correctly: the channel axis for transposed-convolution weights is ch_axis=1, not ch_axis=0.
- Some test suite operators were using random values instead of the desired inputs. These operators require inputs in a specific range to prevent `NaN` outputs.

### Test plan

UT added

cc @cccclai @shewu-quic @haowhsu-quic @DannyYuyang-quic @cbilgin
1 parent 9287f6d commit 810b8e8
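For context on the ch_axis fix: PyTorch lays out regular conv weights with output channels on axis 0, but transposed-conv weights with output channels on axis 1. A quick standalone illustration (standard PyTorch behavior, not code from this commit):

```python
import torch

# Conv2d weight layout is (out_channels, in_channels / groups, kH, kW):
# output channels sit on axis 0, so per-channel quant uses ch_axis=0.
conv = torch.nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3)
print(conv.weight.shape)  # torch.Size([6, 4, 3, 3])

# ConvTranspose2d weight layout is (in_channels, out_channels / groups, kH, kW):
# output channels sit on axis 1, hence the ch_axis=1 fix in this commit.
tconv = torch.nn.ConvTranspose2d(in_channels=4, out_channels=6, kernel_size=3)
print(tconv.weight.shape)  # torch.Size([4, 6, 3, 3])
```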

File tree

11 files changed: +409 -76 lines changed

backends/qualcomm/builders/node_visitor.py

Lines changed: 13 additions & 5 deletions
@@ -153,6 +153,13 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         scales, scale_offset, quantized_scales = quant_attrs[QCOM_SCALE], [], []
         # channel in observers defaults to zero
         num_channels = node.meta["val"].shape[0]
+        user_0 = self.get_first_user(node)
+
+        ch_axis = 0
+        # args[6] to check if it is transpose conv
+        if user_0.target == exir_ops.edge.aten.convolution.default and user_0.args[6]:
+            num_channels = node.meta["val"].shape[1]
+            ch_axis = 1
         # TODO: expand this when QNN starts to support more configurations
         bitwidth_of_scale = 4
         quant_scales_dtype = torch.uint8
@@ -162,9 +169,10 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         )
 
         for ch in range(num_channels):
-            max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
+            candidates = scales[ch] if ch_axis == 0 else scales[:, ch, ...]
+            max_scale = candidates.reshape(1, -1).amax(dim=-1) / num_steps
             q_scales = torch.clamp(
-                input=torch.round(input=scales[ch] / max_scale),
+                input=torch.round(input=candidates / max_scale),
                 min=1,
                 max=2**bitwidth_of_scale,
             ).to(quant_scales_dtype)
@@ -174,11 +182,11 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
 
         # skip dequantize op, e.g. frozen_param -> dq -> conv2d
         user_0 = self.get_first_user(node)
-        if "convolution" in user_0.target.__name__:
+        if user_0.target == exir_ops.edge.aten.convolution.default:
             # OIHW (pytorch) -> HWIO (QNN)
             quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
             quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0)
-        elif "linear" in user_0.target.__name__:
+        elif user_0.target == exir_ops.edge.aten.linear.default:
             # OI (pytorch) -> OI (QNN)
             quant_config[QCOM_AXIS] = 0
             quant_config[QCOM_AXIS_ORDER] = (0, 1)
@@ -217,7 +225,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
         # skip dequantize op, e.g. frozen_param -> dq -> conv2d
         user_0 = self.get_first_user(node)
         # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
-        if "convolution" in user_0.target.__name__:
+        if user_0.target == exir_ops.edge.aten.convolution.default:
             quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
         else:
             quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]
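The `candidates` selection above switches the indexing dimension based on `ch_axis`. A minimal sketch of the two paths, with an assumed scales tensor shape:

```python
import torch

# Assumed per-block scales for a transpose-conv weight of shape (in_ch=4, out_ch=6, n_blocks=2).
scales = torch.rand(4, 6, 2)

ch = 3
# ch_axis == 0 path (regular conv, out-channels on dim 0): scales[ch]
# ch_axis == 1 path (transpose conv, out-channels on dim 1):
candidates = scales[:, ch, ...]  # shape (4, 2): every scale belonging to output channel 3
```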

backends/qualcomm/builders/op_conv.py

Lines changed: 43 additions & 9 deletions
@@ -9,9 +9,9 @@
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
 import numpy as np
 import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
 
-from .node_visitor import NodeVisitor
+from .node_visitor import NodeVisitor, PER_CHANNEL_ENCODING
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import (
     OpConv2d,
@@ -101,6 +101,29 @@ def _add_conv_op_parameter(
 
         return conv_op
 
+    def _reduce_bias_scales(
+        self,
+        node: torch.fx.Node,
+        filter_node: torch.fx.Node,
+        bias_node: torch.fx.Node,
+        groups: int,
+    ):
+        """
+        If a transpose_conv has groups, the bias_node's per-channel quant needs special handling.
+        Check _derived_bias_quant_spec under backends/qualcomm/quantizer/qconfig.py for more info.
+        """
+        filter_scales = filter_node.meta[QCOM_QUANT_ATTRS]["scales"]
+        bias_scales = bias_node.meta[QCOM_QUANT_ATTRS]["scales"]
+        bias_zero_points = bias_node.meta[QCOM_QUANT_ATTRS]["zero_points"]
+
+        # This condition prevents reducing twice: once in op_validation and once in qnn_preprocess
+        if filter_scales.numel() != bias_scales.numel():
+            bias_scales = bias_scales.view(-1, groups)[:, 0]
+            bias_zero_points = bias_zero_points.view(-1, groups)[:, 0]
+            bias_node.meta[QCOM_QUANT_ATTRS]["scales"] = bias_scales
+            bias_node.meta[QCOM_QUANT_ATTRS]["zero_points"] = bias_zero_points
+
     def define_node(
         self,
         node: torch.fx.Node,
@@ -127,8 +150,15 @@ def define_node(
 
         filter_node = self.get_node(node.args[1])
         filter_tensor = get_parameter(filter_node, self.edge_program)
+
+        stride = cast(List[int], node.args[3])
+        padding = cast(List[int], node.args[4])
+        dilation = cast(List[int], node.args[5])
+        output_padding = cast(List[int], node.args[7])
+        groups = cast(int, node.args[8])
+
         # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d),
-        # yet QNN is HWIO or DHWIO
+        # yet QNN is HWIO or DHWIO for both conv and conv_transpose.
         is_transpose_conv = cast(bool, node.args[6])
         if is_conv2d:
             filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
@@ -147,6 +177,16 @@ def define_node(
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
         if node.args[2] is not None:
             bias_node = self.get_node(node.args[2])
+            # TODO: Double-check the condition below once QNN supports transpose_conv with block_quant.
+            # By checking node.args[1].target, only allow per_channel_quant to go through and bypass block_quant.
+            if (
+                is_transpose_conv
+                and groups != 1
+                and bias_node.meta.get(QCOM_QUANT_ATTRS) is not None
+                and node.args[1].target in PER_CHANNEL_ENCODING
+            ):
+                self._reduce_bias_scales(node, filter_node, bias_node, groups)
+
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
@@ -156,7 +196,6 @@ def define_node(
                 nodes_to_wrappers,
             )
             conv_input_tensors.append(bias_tensor_wrapper)
-
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
@@ -167,11 +206,6 @@ def define_node(
         )
         conv_output_tensors = [output_tensor_wrapper]
 
-        stride = cast(List[int], node.args[3])
-        padding = cast(List[int], node.args[4])
-        dilation = cast(List[int], node.args[5])
-        output_padding = cast(List[int], node.args[7])
-        groups = cast(int, node.args[8])
         # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout)
         group_input_channels = filter_tensor.shape[-2]
         group_output_channels = int(filter_tensor.shape[-1] / groups)
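The `view(-1, groups)[:, 0]` step in `_reduce_bias_scales` is the inverse of the `repeat_interleave` applied in `qconfig.py`: it keeps one scale out of each run of `groups` repeated values. A standalone sketch with made-up numbers:

```python
import torch

groups = 2
# Bias scales as produced by derived_scale.repeat_interleave(groups) in qconfig.py.
bias_scales = torch.tensor([0.1, 0.1, 0.2, 0.2, 0.3, 0.3])

# Each row is a run of `groups` identical values; keep the first of each run.
reduced = bias_scales.view(-1, groups)[:, 0]
print(reduced)  # tensor([0.1000, 0.2000, 0.3000])
```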

backends/qualcomm/builders/op_elu.py

Lines changed: 2 additions & 3 deletions
@@ -58,12 +58,11 @@ def define_node(
         )
         elu_op.AddInputTensors(elu_input_tensors)
         elu_op.AddOutputTensors(elu_output_tensors)
-
-        if len(node.args) == 2:
+        if len(node.args) > 1:
             elu_op.AddScalarParam(
                 OpElu.param_alpha,
                 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
-                {QCOM_DATA: np.uint32(node.args[1])},
+                {QCOM_DATA: np.float32(node.args[1])},
             )
 
         return elu_op
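The dtype change matters because ELU's alpha is usually fractional, and routing it through `np.uint32` silently truncates it to an integer before QNN ever sees it:

```python
import numpy as np

alpha = 0.5
print(np.uint32(alpha))   # 0   -- the old cast truncated fractional alpha to an integer
print(np.float32(alpha))  # 0.5 -- the new cast matches QNN_DATATYPE_FLOAT_32
```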

backends/qualcomm/quantizer/qconfig.py

Lines changed: 23 additions & 5 deletions
@@ -52,6 +52,22 @@ def _derive_bias_qparams_fn(
             act_scale, weight_scale
         )
         derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
+        # TransposeConv uses per-channel axis=1, and weight_shape[1] = out_channel / groups.
+        # E.g., out_channel = 6, groups = 2, weight_shape[1] = 3, which means there are 3 pairs of scale/offset.
+        # However, bias still has 6 values, so derived_scale must be repeat_interleaved `groups` times in order to
+        # generate 6 pairs of scale/offset for per-channel quantization. For the bias node, the Conv OP builder
+        # will later pass only 3 pairs of scale/offset to QNN.
+        if (
+            node.target
+            in {
+                torch.ops.aten.conv_transpose2d.input,
+                torch.ops.aten.conv_transpose3d.input,
+            }
+            and len(node.args) > 6
+            and node.args[6] != 1
+        ):
+            groups = node.args[6]
+            derived_scale = derived_scale.repeat_interleave(groups)
         derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to(
             torch.int32
         )
@@ -68,7 +84,6 @@ def _derive_bias_qparams_fn(
     assert isinstance(input_act, Node)
     weight = node.args[1]
     assert isinstance(weight, Node)
-
    return DerivedQuantizationSpec(
        derived_from=[(input_act, node), (weight, node)],
        derive_qparams_fn=_derive_bias_qparams_fn,
@@ -300,6 +315,7 @@ def get_ptq_per_channel_quant_config(
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
     act_symmetric: bool = False,
+    ch_axis: int = 0,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
@@ -349,7 +365,7 @@ def get_ptq_per_channel_quant_config(
         ),
         quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
+        ch_axis=ch_axis,
         observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
     )
 
@@ -370,6 +386,7 @@ def get_ptq_per_block_quant_config(
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
     act_symmetric: bool = False,
+    ch_axis: int = 0,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
     quantization_config = get_ptq_per_channel_quant_config(
@@ -385,7 +402,7 @@ def get_ptq_per_block_quant_config(
         ),
         quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
+        ch_axis=ch_axis,
         observer_or_fake_quant_ctr=PerBlockParamObserver.with_args(**extra_args),
     )
     return QuantizationConfig(
@@ -522,6 +539,7 @@ def get_qat_per_channel_quant_config(
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
     act_symmetric=False,
+    ch_axis: int = 0,
 ) -> QuantizationConfig:
     supported_act_types = {
         torch.uint8,
@@ -577,7 +595,7 @@ def get_qat_per_channel_quant_config(
         ),
         quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
+        ch_axis=ch_axis,
         observer=MovingAveragePerChannelMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
@@ -587,7 +605,7 @@ def get_qat_per_channel_quant_config(
         ),
         quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
+        ch_axis=ch_axis,
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
     )
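To make the repeat_interleave comment concrete, here is a small sketch using the numbers from the comment (out_channel = 6, groups = 2, so the weight observer yields 3 scales while the bias has 6 elements):

```python
import torch

# 3 derived scales, one per weight_shape[1] entry (out_channel / groups = 3).
derived_scale = torch.tensor([0.1, 0.2, 0.3])
groups = 2

# Expand so every one of the 6 bias values gets a scale/offset pair.
expanded = derived_scale.repeat_interleave(groups)
print(expanded)  # tensor([0.1000, 0.1000, 0.2000, 0.2000, 0.3000, 0.3000])
```

The Conv OP builder's `_reduce_bias_scales` later undoes this expansion before handing the bias encoding to QNN.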

backends/qualcomm/quantizer/quantizer.py

Lines changed: 60 additions & 25 deletions
@@ -150,33 +150,62 @@ def __post_init__(self):
             if self.act_observer
             else quant_config_func()
         )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_func(act_observer=self.act_observer)
-            if self.act_observer
-            else per_channel_quant_config_func()
-        )
-        self.use_per_channel_weight_quant_ops = set()
+
+        # Assume per_channel_quant/per_block_quant only happen on axis_0 or axis_1; increase the range if there's a need
+        potential_axis = 2
+
+        self.per_channel_quant_config_list = []
+        for i in range(potential_axis):
+            self.per_channel_quant_config_list.append(
+                (
+                    per_channel_quant_config_func(
+                        act_observer=self.act_observer, ch_axis=i
+                    )
+                    if self.act_observer
+                    else per_channel_quant_config_func(ch_axis=i)
+                )
+            )
+
+        # Key is the node target, and value is the axis to perform per-channel quantization on
+        self.op_axis_dict = {
+            torch.ops.aten.conv1d.default: 0,
+            torch.ops.aten.conv2d.default: 0,
+            torch.ops.aten.conv3d.default: 0,
+            torch.ops.aten.conv_transpose2d.input: 1,
+            torch.ops.aten.conv_transpose3d.input: 1,
+            torch.ops.aten.linear.default: 0,
+        }
+
+        self.use_per_channel_weight_quant_ops = {}
         if self.is_conv_per_channel:
+            conv_ops = [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv3d.default,
+                torch.ops.aten.conv_transpose2d.input,
+                torch.ops.aten.conv_transpose3d.input,
+            ]
             self.use_per_channel_weight_quant_ops.update(
-                {
-                    torch.ops.aten.conv1d.default,
-                    torch.ops.aten.conv2d.default,
-                    torch.ops.aten.conv3d.default,
-                    torch.ops.aten.conv_transpose2d.input,
-                }
+                {k: self.op_axis_dict[k] for k in conv_ops if k in self.op_axis_dict}
             )
         if self.is_linear_per_channel:
+            linear_ops = [torch.ops.aten.linear.default]
             self.use_per_channel_weight_quant_ops.update(
-                {
-                    torch.ops.aten.linear.default,
-                }
+                {k: self.op_axis_dict[k] for k in linear_ops if k in self.op_axis_dict}
            )
+
         if per_block_quant_config_func:
-            self.per_block_quant_config = (
-                per_block_quant_config_func(act_observer=self.act_observer)
-                if self.act_observer
-                else per_block_quant_config_func()
-            )
+            self.per_block_quant_config_list = []
+            for i in range(potential_axis):
+                self.per_block_quant_config_list.append(
+                    (
+                        per_block_quant_config_func(
+                            act_observer=self.act_observer, ch_axis=i
+                        )
+                        if self.act_observer
+                        else per_block_quant_config_func(ch_axis=i)
+                    )
+                )
 
 
 class QnnQuantizer(Quantizer):
@@ -269,16 +298,22 @@ def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]
         op = node.target
         if isinstance(op, str):
             return
-
+        config = self._get_submodule_qconfig(node)
         if block_size := self.block_size_map.get(node.name):
-            config = self.default_quant_config.per_block_quant_config
+            ch_axis = config.op_axis_dict.get(node.target, 0)
+            assert (
+                len(config.per_block_quant_config_list) > ch_axis
+            ), f"Unsupported per block quantization axis: {ch_axis}, please increase the range of per_block_quant_config_list"
+            config = config.per_block_quant_config_list[ch_axis]
             config.block_size = block_size
             return config
 
-        config = self._get_submodule_qconfig(node)
-
         if op in config.use_per_channel_weight_quant_ops:
-            return config.per_channel_quant_config
+            ch_axis = config.use_per_channel_weight_quant_ops[op]
+            assert (
+                len(config.per_channel_quant_config_list) > ch_axis
+            ), f"Unsupported per channel quantization axis: {ch_axis}, please increase the range of per_channel_quant_config_list"
+            return config.per_channel_quant_config_list[ch_axis]
 
         if op in self.quant_ops:
             return config.quant_config
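A reduced sketch of the new dispatch: the op target picks the weight's output-channel axis, and the axis indexes into the prebuilt config list (targets limited to those touched by this commit):

```python
import torch

op_axis_dict = {
    torch.ops.aten.conv2d.default: 0,
    torch.ops.aten.conv_transpose2d.input: 1,
    torch.ops.aten.linear.default: 0,
}

# Unknown targets fall back to axis 0, mirroring _get_quant_config's .get(node.target, 0).
ch_axis = op_axis_dict.get(torch.ops.aten.conv_transpose2d.input, 0)
print(ch_axis)  # 1 -> per_channel_quant_config_list[1] is used for transpose conv
```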

backends/qualcomm/tests/models.py

Lines changed: 17 additions & 6 deletions
@@ -746,15 +746,26 @@ def forward(self, x):
 
 
 class ConvTranspose2dSingle(torch.nn.Module):
-    def __init__(self, bias=True, dilation=1):
+    def __init__(
+        self,
+        bias=True,
+        in_channels=1,
+        out_channels=3,
+        kernel_size=1,
+        stride=1,
+        padding=1,
+        dilation=1,
+        groups=1,
+    ):
         super().__init__()
         self.conv_transpose = torch.nn.ConvTranspose2d(
-            in_channels=1,
-            out_channels=3,
-            kernel_size=3,
-            stride=2,
-            padding=1,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
             dilation=dilation,
+            groups=groups,
             bias=bias,
         )
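A hypothetical instantiation exercising the new grouped transpose-conv path (parameter names from the module above; the input shape is an assumption):

```python
import torch

# groups=2 with out_channels=6 hits the grouped per-channel bias handling described above.
model = ConvTranspose2dSingle(
    in_channels=4,
    out_channels=6,
    kernel_size=3,
    stride=2,
    padding=1,
    groups=2,
)
sample_input = torch.randn(1, 4, 8, 8)
output = model.conv_transpose(sample_input)
```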
