18 changes: 13 additions & 5 deletions backends/qualcomm/builders/node_visitor.py
@@ -153,6 +153,13 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         scales, scale_offset, quantized_scales = quant_attrs[QCOM_SCALE], [], []
         # channel in observers defaults to zero
         num_channels = node.meta["val"].shape[0]
+        user_0 = self.get_first_user(node)
+
+        ch_axis = 0
+        # args[6] indicates whether this is a transposed conv
+        if user_0.target == exir_ops.edge.aten.convolution.default and user_0.args[6]:
+            num_channels = node.meta["val"].shape[1]
+            ch_axis = 1
         # TODO: expand this when QNN starts to support more configurations
         bitwidth_of_scale = 4
         quant_scales_dtype = torch.uint8
@@ -162,9 +169,10 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         )
 
         for ch in range(num_channels):
-            max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
+            candidates = scales[ch] if ch_axis == 0 else scales[:, ch, ...]
+            max_scale = candidates.reshape(1, -1).amax(dim=-1) / num_steps
             q_scales = torch.clamp(
-                input=torch.round(input=scales[ch] / max_scale),
+                input=torch.round(input=candidates / max_scale),
                 min=1,
                 max=2**bitwidth_of_scale,
             ).to(quant_scales_dtype)
@@ -174,11 +182,11 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
 
         # skip dequantize op, e.g. frozen_param -> dq -> conv2d
         user_0 = self.get_first_user(node)
-        if "convolution" in user_0.target.__name__:
+        if user_0.target == exir_ops.edge.aten.convolution.default:
             # OIHW (pytorch) -> HWIO (QNN)
             quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
             quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0)
-        elif "linear" in user_0.target.__name__:
+        elif user_0.target == exir_ops.edge.aten.linear.default:
             # OI (pytorch) -> OI (QNN)
             quant_config[QCOM_AXIS] = 0
             quant_config[QCOM_AXIS_ORDER] = (0, 1)
@@ -217,7 +225,7 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
         # skip dequantize op, e.g. frozen_param -> dq -> conv2d
         user_0 = self.get_first_user(node)
         # The memory layout of a QNN conv weight always ends in the output channel, e.g. conv2d is HWIO
-        if "convolution" in user_0.target.__name__:
+        if user_0.target == exir_ops.edge.aten.convolution.default:
             quant_config[QCOM_AXIS] = node.meta["val"].dim() - 1
         else:
             quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]
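To see what the new ch_axis handling computes, here is a minimal standalone sketch of the 4-bit per-block scale encoding above. It is illustrative only: `num_steps` and the tensor shapes are assumptions, since those lines are collapsed in the diff.

```python
import torch

# Minimal sketch of the per-block scale encoding above (assumed shapes).
bitwidth_of_scale = 4
num_steps = 2**bitwidth_of_scale  # assumed definition; hidden in the collapsed diff

scales = torch.rand(6, 8)  # e.g. 6 x 8 block scales for one weight tensor
ch_axis = 1                # transposed conv: output channels live on axis 1
num_channels = scales.shape[ch_axis]

for ch in range(num_channels):
    # Select this channel's scales no matter which axis holds the channels
    candidates = scales[ch] if ch_axis == 0 else scales[:, ch]
    # Normalize so every block scale fits into a 4-bit multiplier
    max_scale = candidates.reshape(1, -1).amax(dim=-1) / num_steps
    q_scales = torch.clamp(
        torch.round(candidates / max_scale),
        min=1,
        max=2**bitwidth_of_scale,
    ).to(torch.uint8)
```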
52 changes: 43 additions & 9 deletions backends/qualcomm/builders/op_conv.py
@@ -9,9 +9,9 @@
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
 import numpy as np
 import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
 
-from .node_visitor import NodeVisitor
+from .node_visitor import NodeVisitor, PER_CHANNEL_ENCODING
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import (
     OpConv2d,
@@ -101,6 +101,29 @@ def _add_conv_op_parameter(
 
         return conv_op
 
+    def _reduce_bias_scales(
+        self,
+        node: torch.fx.Node,
+        filter_node: torch.fx.Node,
+        bias_node: torch.fx.Node,
+        groups: int,
+    ):
+        """
+        If a transposed conv has groups, the bias node's per-channel quantization needs special handling.
+        Check _derived_bias_quant_spec under backends/qualcomm/quantizer/qconfig.py for more info.
+        """
+
+        filter_scales = filter_node.meta[QCOM_QUANT_ATTRS]["scales"]
+        bias_scales = bias_node.meta[QCOM_QUANT_ATTRS]["scales"]
+        bias_zero_points = bias_node.meta[QCOM_QUANT_ATTRS]["zero_points"]
+
+        # This condition prevents reducing twice: once in op_validation and once in qnn_preprocess
+        if filter_scales.numel() != bias_scales.numel():
+            bias_scales = bias_scales.view(-1, groups)[:, 0]
+            bias_zero_points = bias_zero_points.view(-1, groups)[:, 0]
+            bias_node.meta[QCOM_QUANT_ATTRS]["scales"] = bias_scales
+            bias_node.meta[QCOM_QUANT_ATTRS]["zero_points"] = bias_zero_points
+
     def define_node(
         self,
         node: torch.fx.Node,
@@ -127,8 +150,15 @@ def define_node(
 
         filter_node = self.get_node(node.args[1])
         filter_tensor = get_parameter(filter_node, self.edge_program)
+
+        stride = cast(List[int], node.args[3])
+        padding = cast(List[int], node.args[4])
+        dilation = cast(List[int], node.args[5])
+        output_padding = cast(List[int], node.args[7])
+        groups = cast(int, node.args[8])
+
         # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d),
-        # yet QNN is HWIO or DHWIO
+        # yet QNN is HWIO or DHWIO for both conv and conv_transpose.
         is_transpose_conv = cast(bool, node.args[6])
         if is_conv2d:
             filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
@@ -147,6 +177,16 @@ def define_node(
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
         if node.args[2] is not None:
             bias_node = self.get_node(node.args[2])
+            # TODO: Double-check the condition below once QNN supports transpose_conv with block_quant.
+            # Checking node.args[1].target lets only per_channel_quant through and bypasses block_quant.
+            if (
+                is_transpose_conv
+                and groups != 1
+                and bias_node.meta.get(QCOM_QUANT_ATTRS) is not None
+                and node.args[1].target in PER_CHANNEL_ENCODING
+            ):
+                self._reduce_bias_scales(node, filter_node, bias_node, groups)
+
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
@@ -156,7 +196,6 @@ def define_node(
                 nodes_to_wrappers,
             )
             conv_input_tensors.append(bias_tensor_wrapper)
-
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
@@ -167,11 +206,6 @@ def define_node(
         )
         conv_output_tensors = [output_tensor_wrapper]
 
-        stride = cast(List[int], node.args[3])
-        padding = cast(List[int], node.args[4])
-        dilation = cast(List[int], node.args[5])
-        output_padding = cast(List[int], node.args[7])
-        groups = cast(int, node.args[8])
         # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout)
         group_input_channels = filter_tensor.shape[-2]
         group_output_channels = int(filter_tensor.shape[-1] / groups)
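To illustrate _reduce_bias_scales above with concrete, made-up numbers: for a grouped transposed conv the quantizer hands the bias out_channels scales, while the filter only carries out_channels / groups, and the builder drops the duplicates before passing them to QNN. A minimal sketch:

```python
import torch

# Illustrative values: out_channels = 6, groups = 2, so the filter carries
# 6 / 2 = 3 per-channel scales while the bias was given all 6 (see qconfig.py).
groups = 2
filter_scales = torch.tensor([0.10, 0.20, 0.30])
bias_scales = torch.tensor([0.10, 0.10, 0.20, 0.20, 0.30, 0.30])
bias_zero_points = torch.zeros(6, dtype=torch.int32)

# Same guard as the builder: only reduce while bias still has the expanded size,
# so running this twice (op_validation, then qnn_preprocess) is a no-op.
if filter_scales.numel() != bias_scales.numel():
    bias_scales = bias_scales.view(-1, groups)[:, 0]
    bias_zero_points = bias_zero_points.view(-1, groups)[:, 0]

print(bias_scales)  # tensor([0.1000, 0.2000, 0.3000])
```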
5 changes: 2 additions & 3 deletions backends/qualcomm/builders/op_elu.py
@@ -58,12 +58,11 @@ def define_node(
         )
         elu_op.AddInputTensors(elu_input_tensors)
         elu_op.AddOutputTensors(elu_output_tensors)
-
-        if len(node.args) == 2:
+        if len(node.args) > 1:
             elu_op.AddScalarParam(
                 OpElu.param_alpha,
                 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
-                {QCOM_DATA: np.uint32(node.args[1])},
+                {QCOM_DATA: np.float32(node.args[1])},
             )
 
         return elu_op
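The np.uint32 -> np.float32 fix matters because ELU's alpha is a fractional coefficient: elu(x) = x for x > 0, else alpha * (exp(x) - 1), so routing alpha through uint32 silently truncated values like 0.5 to 0. A small sketch of the effect:

```python
import numpy as np

def elu(x: float, alpha: float) -> float:
    # ELU: identity for positive inputs, alpha-scaled exponential for negative ones
    return x if x > 0 else alpha * (np.exp(x) - 1.0)

alpha = 0.5
print(elu(-1.0, np.float32(alpha)))  # ~-0.316, the intended result
print(elu(-1.0, np.uint32(alpha)))   # 0.0 - alpha was truncated to 0
```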
28 changes: 23 additions & 5 deletions backends/qualcomm/quantizer/qconfig.py
@@ -52,6 +52,22 @@ def _derive_bias_qparams_fn(
             act_scale, weight_scale
         )
         derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
+        # TransposeConv is per-channel on axis=1, and weight_shape[1] = out_channels / groups.
+        # E.g., out_channels = 6, groups = 2, weight_shape[1] = 3, which means there are 3 pairs of scale/offset.
+        # However, bias still has 6 values, so derived_scale must be repeat_interleaved `groups` times in order to
+        # generate 6 pairs of scale/offset for per-channel quantization. For the bias node, the Conv op builder
+        # will later pass only 3 pairs of scale/offset to QNN.
+        if (
+            node.target
+            in {
+                torch.ops.aten.conv_transpose2d.input,
+                torch.ops.aten.conv_transpose3d.input,
+            }
+            and len(node.args) > 6
+            and node.args[6] != 1
+        ):
+            groups = node.args[6]
+            derived_scale = derived_scale.repeat_interleave(groups)
         derived_zero = torch.zeros(derived_scale.size(), device=weight_zp.device).to(
             torch.int32
         )
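A worked version of the comment above, with illustrative numbers (out_channels = 6, groups = 2):

```python
import torch

# Illustration of the grouped transposed-conv case described above.
groups = 2
act_scale = torch.tensor(0.02)
weight_scale = torch.tensor([0.1, 0.2, 0.3])  # weight_shape[1] = out_channels / groups = 3

derived_scale = (act_scale * weight_scale).to(torch.float32)  # 3 values
derived_scale = derived_scale.repeat_interleave(groups)       # back to 6 values

print(derived_scale)
# tensor([0.0020, 0.0020, 0.0040, 0.0040, 0.0060, 0.0060])
# One scale per bias element; the Conv op builder later reduces these back to 3 for QNN.
```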
@@ -68,7 +84,6 @@ def _derive_bias_qparams_fn(
     assert isinstance(input_act, Node)
     weight = node.args[1]
     assert isinstance(weight, Node)
-
     return DerivedQuantizationSpec(
         derived_from=[(input_act, node), (weight, node)],
         derive_qparams_fn=_derive_bias_qparams_fn,
@@ -300,6 +315,7 @@ def get_ptq_per_channel_quant_config(
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
     act_symmetric: bool = False,
+    ch_axis: int = 0,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
@@ -349,7 +365,7 @@ def get_ptq_per_channel_quant_config(
         ),
         quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
+        ch_axis=ch_axis,
         observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
     )
 
@@ -370,6 +386,7 @@ def get_ptq_per_block_quant_config(
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
     act_symmetric: bool = False,
+    ch_axis: int = 0,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
     quantization_config = get_ptq_per_channel_quant_config(
@@ -385,7 +402,7 @@ def get_ptq_per_block_quant_config(
         ),
         quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
+        ch_axis=ch_axis,
         observer_or_fake_quant_ctr=PerBlockParamObserver.with_args(**extra_args),
     )
     return QuantizationConfig(
@@ -522,6 +539,7 @@ def get_qat_per_channel_quant_config(
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
     act_symmetric=False,
+    ch_axis: int = 0,
 ) -> QuantizationConfig:
     supported_act_types = {
         torch.uint8,
@@ -577,7 +595,7 @@ def get_qat_per_channel_quant_config(
         ),
         quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
+        ch_axis=ch_axis,
         observer=MovingAveragePerChannelMinMaxObserver,
     )
     weight_quantization_spec = QuantizationSpec(
@@ -587,7 +605,7 @@ def get_qat_per_channel_quant_config(
         ),
         quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
         qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
+        ch_axis=ch_axis,
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
     )
 
85 changes: 60 additions & 25 deletions backends/qualcomm/quantizer/quantizer.py
@@ -150,33 +150,62 @@ def __post_init__(self):
             if self.act_observer
             else quant_config_func()
         )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_func(act_observer=self.act_observer)
-            if self.act_observer
-            else per_channel_quant_config_func()
-        )
-        self.use_per_channel_weight_quant_ops = set()
 
+        # Assume per_channel_quant/per_block_quant only happens on axis 0 or axis 1; widen the range if needed
+        potential_axis = 2
+
+        self.per_channel_quant_config_list = []
+        for i in range(potential_axis):
+            self.per_channel_quant_config_list.append(
+                (
+                    per_channel_quant_config_func(
+                        act_observer=self.act_observer, ch_axis=i
+                    )
+                    if self.act_observer
+                    else per_channel_quant_config_func(ch_axis=i)
+                )
+            )
+
+        # Key is the node target; value is the axis on which to perform per-channel quantization
+        self.op_axis_dict = {
+            torch.ops.aten.conv1d.default: 0,
+            torch.ops.aten.conv2d.default: 0,
+            torch.ops.aten.conv3d.default: 0,
+            torch.ops.aten.conv_transpose2d.input: 1,
+            torch.ops.aten.conv_transpose3d.input: 1,
+            torch.ops.aten.linear.default: 0,
+        }
+
+        self.use_per_channel_weight_quant_ops = {}
         if self.is_conv_per_channel:
+            conv_ops = [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv3d.default,
+                torch.ops.aten.conv_transpose2d.input,
+                torch.ops.aten.conv_transpose3d.input,
+            ]
             self.use_per_channel_weight_quant_ops.update(
-                {
-                    torch.ops.aten.conv1d.default,
-                    torch.ops.aten.conv2d.default,
-                    torch.ops.aten.conv3d.default,
-                    torch.ops.aten.conv_transpose2d.input,
-                }
+                {k: self.op_axis_dict[k] for k in conv_ops if k in self.op_axis_dict}
             )
         if self.is_linear_per_channel:
+            linear_ops = [torch.ops.aten.linear.default]
             self.use_per_channel_weight_quant_ops.update(
-                {
-                    torch.ops.aten.linear.default,
-                }
+                {k: self.op_axis_dict[k] for k in linear_ops if k in self.op_axis_dict}
             )
 
         if per_block_quant_config_func:
-            self.per_block_quant_config = (
-                per_block_quant_config_func(act_observer=self.act_observer)
-                if self.act_observer
-                else per_block_quant_config_func()
-            )
+            self.per_block_quant_config_list = []
+            for i in range(potential_axis):
+                self.per_block_quant_config_list.append(
+                    (
+                        per_block_quant_config_func(
+                            act_observer=self.act_observer, ch_axis=i
+                        )
+                        if self.act_observer
+                        else per_block_quant_config_func(ch_axis=i)
+                    )
+                )
 
 
 class QnnQuantizer(Quantizer):
@@ -269,16 +298,22 @@ def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         op = node.target
         if isinstance(op, str):
             return
 
+        config = self._get_submodule_qconfig(node)
         if block_size := self.block_size_map.get(node.name):
-            config = self.default_quant_config.per_block_quant_config
+            ch_axis = config.op_axis_dict.get(node.target, 0)
+            assert (
+                len(config.per_block_quant_config_list) > ch_axis
+            ), f"Unsupported per block quantization axis: {ch_axis}, please increase the range of per_block_quant_config_list"
+            config = config.per_block_quant_config_list[ch_axis]
             config.block_size = block_size
             return config
 
-        config = self._get_submodule_qconfig(node)
-
         if op in config.use_per_channel_weight_quant_ops:
-            return config.per_channel_quant_config
+            ch_axis = config.use_per_channel_weight_quant_ops[op]
+            assert (
+                len(config.per_channel_quant_config_list) > ch_axis
+            ), f"Unsupported per channel quantization axis: {ch_axis}, please increase the range of per_channel_quant_config_list"
+            return config.per_channel_quant_config_list[ch_axis]
 
         if op in self.quant_ops:
             return config.quant_config
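Putting the quantizer changes together: each config now carries one per-channel (and per-block) QuantizationConfig per candidate axis, and lookups go through op_axis_dict. A simplified standalone sketch of the selection logic, with stand-in strings in place of the real config objects:

```python
import torch

# Simplified sketch of the new axis-aware lookup in _get_quant_config.
op_axis_dict = {
    torch.ops.aten.conv2d.default: 0,          # OIHW: output channels on dim 0
    torch.ops.aten.conv_transpose2d.input: 1,  # IOHW: output channels on dim 1
    torch.ops.aten.linear.default: 0,
}

# Stand-ins for the real QuantizationConfig objects, one per supported axis
per_channel_quant_config_list = ["config(ch_axis=0)", "config(ch_axis=1)"]

def pick_config(op):
    ch_axis = op_axis_dict.get(op, 0)
    assert len(per_channel_quant_config_list) > ch_axis, (
        f"Unsupported per channel quantization axis: {ch_axis}"
    )
    return per_channel_quant_config_list[ch_axis]

print(pick_config(torch.ops.aten.conv_transpose2d.input))  # config(ch_axis=1)
```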
23 changes: 17 additions & 6 deletions backends/qualcomm/tests/models.py
@@ -746,15 +746,26 @@ def forward(self, x):
 
 
 class ConvTranspose2dSingle(torch.nn.Module):
-    def __init__(self, bias=True, dilation=1):
+    def __init__(
+        self,
+        bias=True,
+        in_channels=1,
+        out_channels=3,
+        kernel_size=1,
+        stride=1,
+        padding=1,
+        dilation=1,
+        groups=1,
+    ):
         super().__init__()
         self.conv_transpose = torch.nn.ConvTranspose2d(
-            in_channels=1,
-            out_channels=3,
-            kernel_size=3,
-            stride=2,
-            padding=1,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
             dilation=dilation,
+            groups=groups,
             bias=bias,
         )
 
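With the extra constructor arguments, a test can now build a grouped transposed conv that exercises the new bias-scale path. A hypothetical usage (argument values are made up):

```python
import torch

# Hypothetical usage of the extended test model: a grouped transposed conv
# (groups != 1) that exercises the new per-channel bias-scale handling.
model = ConvTranspose2dSingle(
    bias=True,
    in_channels=4,
    out_channels=6,
    kernel_size=3,
    stride=2,
    padding=1,
    groups=2,
)
x = torch.randn(1, 4, 8, 8)  # NCHW input; channels must match in_channels
print(model.conv_transpose(x).shape)  # torch.Size([1, 6, 15, 15])
```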