diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py
index 85f65ebbcbb..78d1e6244e9 100644
--- a/backends/qualcomm/builders/op_linear.py
+++ b/backends/qualcomm/builders/op_linear.py
@@ -40,6 +40,14 @@ def define_node(
         linear_input_tensors.append(input_tensor_wrapper)
 
         weight_node = node.args[1]
+        if (
+            quant_attrs := weight_node.meta.get("quant_attrs")
+        ) and "scales" in quant_attrs:
+            # The weight has shape [m, n]; the per-channel quant params have shape [m].
+            # Reshape them to [m, 1] so they broadcast in tensor.div(s).add(z).
+            quant_attrs["scales"] = quant_attrs["scales"].reshape([-1, 1])
+            quant_attrs["zero_points"] = quant_attrs["zero_points"].reshape([-1, 1])
+
         weight_tensor = get_parameter(weight_node, self.edge_program)
         weight_tensor_wrapper = self.define_tensor(
             weight_node,
@@ -52,6 +60,12 @@ def define_node(
 
         if len(node.args) >= 3:
             bias_node = node.args[2]
+
+            # TODO: remove this once the QNN SDK supports per-channel bias quantization
+            if "scales" in bias_node.meta.get("quant_attrs", {}):
+                print(
+                    f"[WARNING] Fallback for linear bias, {bias_node}: per-channel bias quantization is not supported yet."
+                )
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index 674314d991c..1414af171a4 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -267,7 +267,7 @@ def __init__(self):
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
 
-        self.enable_per_channel_conv_quant: bool = True
+        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
         # the weight quantized for activation 8 bits and 16 bits
         self.per_channel_weight_dtype: Dict = {
             "8bit_act": torch.int8,
@@ -290,16 +290,13 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
     def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
         """
         Priority:
-            1. per channel config when enable_per_channel_conv_quant is True
+            1. per channel config when op is in use_per_channel_weight_quant_ops
             2. int8 / int16 config
         """
         if type(op) == str:
             return
-        if self.enable_per_channel_conv_quant and op in [
-            torch.ops.aten.conv1d.default,
-            torch.ops.aten.conv2d.default,
-        ]:
+        if op in self.use_per_channel_weight_quant_ops:
             if op in self.bit16_quant_ops:
                 return get_ptq_per_channel_weight_config(
                     torch.uint16, self.per_channel_weight_dtype["16bit_act"]
                 )
@@ -316,6 +313,12 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
 
         print(f"No quant config is implemented for op, {op}")
 
+    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
+        if enable:
+            self.use_per_channel_weight_quant_ops.update(ops)
+        else:
+            self.use_per_channel_weight_quant_ops.difference_update(ops)
+
     def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None:
         for op in ops:
             assert (
@@ -368,8 +371,15 @@ def set_per_channel_weight_dtype(
         if weight_dtype_for_16bit_act:
             self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act
 
-    def set_per_channel_quant(self, enable: bool) -> None:
-        self.enable_per_channel_conv_quant = enable
+    def set_per_channel_conv_quant(self, enable: bool) -> None:
+        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
+        self._update_per_channel_weight_quant_ops(conv_ops, enable)
+
+    def set_per_channel_linear_quant(self, enable: bool) -> None:
+        linear_ops = {
+            torch.ops.aten.linear.default,
+        }
+        self._update_per_channel_weight_quant_ops(linear_ops, enable)
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = RemoveClone()(model).graph_module
diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index 809b7298eba..77e16efa9b3 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -520,11 +520,11 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
     )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
-            node,
-            bias_node,
-            quantization_config.bias,
-        )
+        if callable(quantization_config.bias):
+            bias_config = quantization_config.bias(node)
+        else:
+            bias_config = quantization_config.bias
+        _annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
     _annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index bbc89276854..edc7a469f7b 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -409,9 +409,9 @@ def forward(self, x):
 
 
 class Linear(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, use_bias: bool = True):
         super().__init__()
-        self.linear = torch.nn.Linear(4, 5).eval()
+        self.linear = torch.nn.Linear(4, 5, use_bias).eval()
 
     def forward(self, x):
         return self.linear(x)
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index b15a876a1f8..d539827fdb9 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -505,7 +505,33 @@ def test_qnn_backend_16a4w_linear(self):
         module = Linear()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
         module = self.get_qdq_module(
-            module, sample_input, quant_dtype=QuantDtype.use_16a4w
+            module,
+            sample_input,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_16a4w_per_channel_linear(self):
+        module = Linear(use_bias=False)  # noqa: F405
+        sample_input = (torch.randn([3, 4]),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            is_linear_per_channel=True,
+            quant_dtype=QuantDtype.use_16a4w,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
+    # Per-channel linear with bias is not enabled in the current QNN SDK release
+    @unittest.expectedFailure
+    def test_qnn_backend_16a4w_per_channel_linear_with_bias(self):
+        module = Linear()  # noqa: F405
+        sample_input = (torch.randn([3, 4]),)
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            is_linear_per_channel=True,
+            quant_dtype=QuantDtype.use_16a4w,
         )
         self.lower_module_and_test_output(module, sample_input)
 
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 5700b5fb17a..59a48f123da 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -225,6 +225,7 @@ def get_qdq_module(
         module: torch.nn.Module,
         inputs: Tuple[torch.Tensor],
        is_conv_per_channel: Optional[bool] = True,
+        is_linear_per_channel: Optional[bool] = False,
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
     ) -> torch.fx.GraphModule:
@@ -232,7 +233,8 @@ def get_qdq_module(
         quantizer = QnnQuantizer()
         quantizer.add_custom_quant_annotations(custom_quant_annotations)
 
-        quantizer.set_per_channel_quant(is_conv_per_channel)
+        quantizer.set_per_channel_conv_quant(is_conv_per_channel)
+        quantizer.set_per_channel_linear_quant(is_linear_per_channel)
 
        if quant_dtype == QuantDtype.use_8a8w:
            pass  # default setting
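
A note on the [m, 1] reshape in op_linear.py: the per-channel scales and zero points carry one entry per output channel, and turning them into a column vector lets them broadcast against the [m, n] weight in tensor.div(s).add(z). A minimal standalone illustration in plain torch (not part of the patch; the 4-bit-style range is only an example):

```python
import torch

m, n = 5, 4
weight = torch.randn(m, n)                        # [m, n]: m output channels
scales = weight.abs().amax(dim=1) / 7.0           # [m]: one scale per channel
zero_points = torch.zeros(m, dtype=torch.int32)   # [m]

# A [m] tensor broadcasts against the last dimension and fails when m != n,
# so the quant params are reshaped to [m, 1] to broadcast one value per row.
q = weight.div(scales.reshape(-1, 1)).add(zero_points.reshape(-1, 1)).round()
assert q.shape == (m, n)
```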
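
With this change, per-channel weight quantization is toggled per op family rather than through a single conv flag. A minimal usage sketch (the import path is assumed from the repo layout; the surrounding pt2e prepare/convert flow is unchanged and omitted):

```python
# Assumed import path, mirroring backends/qualcomm/quantizer/quantizer.py.
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

quantizer = QnnQuantizer()
quantizer.set_per_channel_conv_quant(True)    # aten.conv1d / aten.conv2d weights
quantizer.set_per_channel_linear_quant(True)  # aten.linear weights
# Pass `quantizer` to the usual prepare_pt2e / convert_pt2e flow afterwards.
```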