
Commit e09726d

Qualcomm AI Engine Direct - Quantization Recipe for LLM
- add a fine-grained quantization annotation mechanism (quantization recipe)
- applied to Llama3-1B/3B with fine-grained quantization configs
1 parent c02fdfc commit e09726d

9 files changed (+1127, -371 lines)
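Note: the annotation helpers changed below are consumed as custom annotation callbacks on the QNN quantizer before PT2E calibration. A minimal sketch of that flow, assuming the QnnQuantizer API in this backend; exported_gm, the empty kv_quant_attrs, and the calibration step are placeholders, not part of this commit:

from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_kv_8bit,
    annotate_prefill_kv_output,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = QnnQuantizer()
# Each callback receives the exported GraphModule and stamps quantization
# annotations onto the nodes it recognizes; partial() binds extra arguments.
quantizer.add_custom_quant_annotations(
    (
        annotate_kv_8bit,
        partial(annotate_prefill_kv_output, kv_quant_attrs={}),  # {} is a placeholder
    )
)
prepared = prepare_pt2e(exported_gm, quantizer)  # exported_gm: an exported GraphModule
# ... run calibration samples through `prepared` ...
quantized = convert_pt2e(prepared)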

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 1 addition & 118 deletions
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from enum import Enum, unique
+
 from typing import Sequence
 
 import torch
@@ -17,7 +17,6 @@
     get_8a8w_qnn_ptq_config,
     get_8a8w_qnn_qat_config,
     get_ptq_per_channel_quant_config,
-    get_qat_per_channel_quant_config,
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -32,36 +31,6 @@
 )
 
 
-def annotate_down_proj(
-    gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
-):
-    for node in gm.graph.nodes:
-        if (
-            node.target == torch.ops.aten.conv2d.default
-            and any(s in node.meta["stack_trace"] for s in ["forward_feedfoward_conv"])
-            and node.args[0].target == torch.ops.aten.mul.Tensor
-        ):
-            input_qspec_map = {}
-            input_qspec_map[node.args[0]] = quantization_config.input_activation
-            input_qspec_map[node.args[1]] = quantization_config.weight
-            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                output_qspec=quantization_config.output_activation,
-                _annotated=True,
-            )
-
-
-@unique
-class StaticLLMQuantConfig(Enum):
-    """
-    Layer namespace configuration for Qualcomm's static LLaMA quantization.
-    """
-
-    wq_sha = "wq_sha"  # Query weight (single head)
-    wk_sha = "wk_sha"  # Key weight (single head)
-    wv_sha = "wv_sha"  # Value weight (single head)
-
-
 def annotate_eurobert(gm: torch.fx.GraphModule):
     """
     QNN does not support int32 -> signed 16bit quant
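The removed annotate_down_proj (like annotate_qkv_proj_sha further down) selects conv nodes by substring-matching a layer tag against node.meta["stack_trace"]. A tiny illustrative predicate showing just the matching idea; the tag string follows whatever naming convention the model source uses (e.g. the enum values above) and is not defined by this commit:

import torch

def matches_layer_tag(node: torch.fx.Node, tag: str) -> bool:
    # stack_trace records the Python call site that produced this op, so a
    # source-level naming convention such as "wq_sha" appears as a substring.
    return (
        node.target == torch.ops.aten.conv2d.default
        and tag in node.meta.get("stack_trace", "")
    )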
@@ -123,49 +92,6 @@ def annotate_mimi_decoder(gm: torch.fx.GraphModule):
             break
 
 
-def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None:
-    """
-    This function is for static LLM models.
-    This function will annotate the last conv(linear), which is the lm_head, as 16a8w.
-    """
-
-    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
-        input_qspec_map = {}
-        input_act = node.args[0]
-        input_spec = quantization_config.input_activation
-        input_qspec_map[input_act] = input_spec
-
-        weight = node.args[1]
-        input_qspec_map[weight] = quantization_config.weight
-
-        if len(node.args) > 2 and isinstance(node.args[2], Node):
-            input_qspec_map[node.args[2]] = quantization_config.bias(node)
-
-        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=quantization_config.output_activation,
-            _annotated=True,
-        )
-
-    if is_qat:
-        quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config(
-            torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
-        )
-    else:
-        quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
-            torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
-        )
-    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
-            if "nn_module_stack" in node.meta:
-                module_values_list = list(node.meta["nn_module_stack"].values())
-                full_qualified_name = module_values_list[-1][0]
-                if full_qualified_name == "output.conv":
-                    annotate_conv2d(
-                        node, quantization_config=quantization_config_16a8w_per_channel
-                    )
-
-
 def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
     for node in gm.graph.nodes:
         if node.op == "output":
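The removed annotate_output_16a8w located the lm_head by inspecting node.meta["nn_module_stack"]: the last entry holds the fully-qualified name of the module that owns the node, which it compared against "output.conv". A small standalone sketch of that lookup (the helper name is illustrative only):

import torch

def is_lm_head_conv(node: torch.fx.Node) -> bool:
    # Each nn_module_stack value is (fully-qualified module name, module type)
    # for the module hierarchy that produced this call_function node.
    stack = node.meta.get("nn_module_stack")
    if not stack:
        return False
    fqn, _module_type = list(stack.values())[-1]
    return node.target == torch.ops.aten.conv2d.default and fqn == "output.conv"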
@@ -200,48 +126,6 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
                 )
 
 
-def annotate_qkv_proj_sha(
-    gm: torch.fx.GraphModule,
-    quantization_config: QuantizationConfig,
-    qkv_tags: set[StaticLLMQuantConfig],
-):
-    """
-    Annotates QKV projection layers in a GraphModule for quantization,
-    specifically layers defined in StaticLLMQuantConfig.
-
-    Args:
-        qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers
-            (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in
-            StaticLLMQuantConfig are allowed.
-
-    Raises:
-        ValueError: If any tag in `qkv_tags` is not among the allowed enum members.
-    """
-
-    # Get all valid tags from the StaticLLMQuantConfig enum
-    allowed_tags = set(StaticLLMQuantConfig)
-    invalid_tags = qkv_tags - allowed_tags
-    if invalid_tags:
-        raise ValueError(
-            f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}"
-        )
-
-    for node in gm.graph.nodes:
-        if node.target == torch.ops.aten.conv2d.default and any(
-            tag.value in node.meta["stack_trace"] for tag in qkv_tags
-        ):
-            input_qspec_map = {}
-            input_qspec_map[node.args[0]] = quantization_config.input_activation
-            input_qspec_map[node.args[1]] = quantization_config.weight
-            if len(node.args) > 2 and isinstance(node.args[2], Node):
-                input_qspec_map[node.args[2]] = quantization_config.bias(node)
-            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                output_qspec=quantization_config.output_activation,
-                _annotated=True,
-            )
-
-
 def annotate_kv_8bit(  # noqa: C901
     gm: torch.fx.GraphModule,
     is_qat=False,
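Before this commit, a caller tagged the single-head QKV projections for annotation roughly as below; the 16a8w per-channel config mirrors the one used by the removed annotate_output_16a8w and is only an illustrative choice, and gm stands for an exported GraphModule:

qkv_config = get_ptq_per_channel_quant_config(
    torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
)
annotate_qkv_proj_sha(
    gm,
    quantization_config=qkv_config,
    qkv_tags={
        StaticLLMQuantConfig.wq_sha,
        StaticLLMQuantConfig.wk_sha,
        StaticLLMQuantConfig.wv_sha,
    },
)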
@@ -262,7 +146,6 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
         input_act = node.args[0]
         input_spec = quantization_config.input_activation
         input_qspec_map[input_act] = input_spec
-
         input_act1 = node.args[1]
         input_spec1 = quantization_config.weight
         input_qspec_map[input_act1] = input_spec1

backends/qualcomm/quantizer/qconfig.py

Lines changed: 41 additions & 0 deletions
@@ -136,6 +136,47 @@ def get_8a8w_qnn_ptq_config(
     return quantization_config
 
 
+def get_8a4w_qnn_ptq_config(
+    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
+) -> QuantizationConfig:
+    extra_args: Dict[str, Any] = {"eps": 2**-12}
+
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.uint8,
+        qscheme=(
+            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
+        ),
+        ch_axis=0,
+        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+    )
+
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-7,
+        quant_max=7,
+        qscheme=torch.per_tensor_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
+    )
+
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
+    )
+
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 # 4 bits quantization only supports specific ops.
 def get_16a4w_qnn_ptq_config(
     act_observer=MovingAverageMinMaxObserver,
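A minimal usage sketch for the new helper: quant_min/quant_max of -7/7 keep the int8-typed weights inside a symmetric 4-bit range, while activations stay uint8 (affine by default, symmetric when act_symmetric=True); the import path simply follows the file location:

from executorch.backends.qualcomm.quantizer.qconfig import get_8a4w_qnn_ptq_config

cfg = get_8a4w_qnn_ptq_config(act_symmetric=True)
assert cfg.weight.quant_min == -7 and cfg.weight.quant_max == 7
# cfg can then be bound to a custom annotation callback, e.g. via functools.partial.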
