
Commit 9a1dff2

Merge pull request #7 from anzr299/origin/nncf_compression

[NNCF] WC Support in OVQuantizer

2 parents: 1716834 + 198190e

File tree

2 files changed: +228 −56 lines
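For orientation, here is a minimal usage sketch (not part of the commit) of how the new weights-only compression modes are intended to be driven through the PT2E flow. get_model and the example input are placeholders, and the exact prepare/convert entry points may differ across torch/torchao versions.

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from executorch.backends.openvino.quantizer.quantizer import OpenVINOQuantizer, QuantizationMode

    model = get_model().eval()              # placeholder for a real nn.Module
    example_input = (torch.randn(1, 128),)  # placeholder input

    exported = torch.export.export(model, example_input).module()
    quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT4_SYM_WC)

    prepared = prepare_pt2e(exported, quantizer)  # attaches the weight observers added in this commit
    prepared(*example_input)                      # one forward pass populates min_val/max_val
    compressed = convert_pt2e(prepared)           # observers' convert() swaps in NNCF decompressors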
backends/openvino/quantizer/observers/nncf_observers.py (new file)

Lines changed: 114 additions & 0 deletions

# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.ao.quantization.observer import (
    MappingType,
    PerAxis,
    PerChannelMinMaxObserver,
    PerGroup,
    get_block_size,
)
from torch.ao.quantization.pt2e._affine_quantization import (
    _get_reduction_params,
    AffineQuantizedMinMaxObserver,
)

from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.experimental.torch.fx.transformations import (
    constant_update_fn,
    module_insertion_transformation_builder,
)
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization
from nncf.tensor.tensor import Tensor
from nncf.torch.graph.transformations.commands import PTTargetPoint, TargetType
from nncf.torch.quantization.layers import (
    INT4AsymmetricWeightsDecompressor,
    INT4SymmetricWeightsDecompressor,
    INT8AsymmetricWeightsDecompressor,
    INT8SymmetricWeightsDecompressor,
)


class PTPerBlockParamObserver(AffineQuantizedMinMaxObserver):
    """Per-group (blockwise) INT4 weight observer backed by NNCF weight lowering."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        qmode = (
            CompressWeightsMode.INT4_ASYM
            if self.mapping_type == MappingType.ASYMMETRIC
            else CompressWeightsMode.INT4_SYM
        )
        assert isinstance(self.granularity, PerGroup), "Only PerGroup granularity is supported"
        self.wc_config = WeightCompressionConfig(mode=qmode, group_size=self.granularity.group_size)

    def calculate_qparams(self, weight):
        assert hasattr(self, "min_val") and hasattr(self, "max_val"), (
            "Expecting the observer has min_val and max_val, "
            "please run the observer before calling calculate_qparams"
        )
        _, reduction_dims = _get_reduction_params(self.block_size, weight.size())
        assert len(reduction_dims) == 1, "Only 1-D group size is supported"
        # The blocked view adds an extra dimension; shift back to the axis of
        # the original (unblocked) weight tensor expected by NNCF.
        reduction_dims = reduction_dims[0] - 1
        q_weight, scale, zp = do_integer_quantization(
            Tensor(weight), self.wc_config, reduction_axes=reduction_dims
        )
        zp = zp.data if zp is not None else None
        return q_weight.data, scale.data, zp

    def convert(self, model: torch.fx.GraphModule, observer_node: torch.fx.Node):
        assert self.original_dtype is not None, "Expecting original_dtype to be populated"
        weight_node = observer_node.args[0]
        original_weight = get_tensor_constant_from_node(weight_node, model)
        q_weight, scale, zero_point = self.calculate_qparams(original_weight)

        with model.graph.inserting_before(observer_node):
            # Asymmetric quantization produces a zero point; symmetric does not.
            if zero_point is not None:
                decompressor = INT4AsymmetricWeightsDecompressor(
                    scale, zero_point, q_weight.shape, original_weight.shape, original_weight.dtype
                )
            else:
                decompressor = INT4SymmetricWeightsDecompressor(
                    scale, q_weight.shape, original_weight.shape, original_weight.dtype
                )
            # Replace the fp weight constant with the packed INT4 payload and
            # insert the decompressor as a post-hook on the new constant.
            packed_q_weight = decompressor.pack_weight(q_weight)
            new_weight_node = constant_update_fn(model, observer_node, packed_q_weight, input_port_id=0)
            decompressor_name = f"NNCFDecompressor_{new_weight_node.name}"
            module_insertion_transformation_builder(
                decompressor,
                [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=new_weight_node.name)],
                decompressor_name,
            )(model)
        # The observer node is no longer needed: rewire its users to the
        # decompressor output and erase it.
        decomp_node = observer_node.args[0]
        observer_node.replace_all_uses_with(decomp_node)
        model.graph.erase_node(observer_node)


class NNCFInt8observer(PerChannelMinMaxObserver):
    """Per-channel INT8 weight observer backed by NNCF weight lowering."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        qmode = (
            CompressWeightsMode.INT8_SYM
            if self.qscheme == torch.per_channel_symmetric
            else CompressWeightsMode.INT8_ASYM
        )
        self.wc_config = WeightCompressionConfig(mode=qmode)

    def calculate_qparams(self, weight):
        assert hasattr(self, "min_val") and hasattr(self, "max_val"), (
            "Expecting the observer has min_val and max_val, "
            "please run the observer before calling calculate_qparams"
        )
        self.granularity = PerAxis(axis=self.ch_axis)
        self.block_size = get_block_size(weight.shape, self.granularity)
        _, reduction_dims = _get_reduction_params(self.block_size, weight.size())
        q_weight, scale, zp = do_integer_quantization(
            Tensor(weight), self.wc_config, reduction_axes=reduction_dims
        )
        zp = zp.data if zp is not None else None
        return q_weight.data, scale.data, zp

    def convert(self, model: torch.fx.GraphModule, observer_node: torch.fx.Node):
        weight_node = observer_node.args[0]
        original_weight = get_tensor_constant_from_node(weight_node, model)
        q_weight, scale, zero_point = self.calculate_qparams(original_weight)

        with model.graph.inserting_before(observer_node):
            if zero_point is not None:
                decompressor = INT8AsymmetricWeightsDecompressor(scale, zero_point, original_weight.dtype)
            else:
                decompressor = INT8SymmetricWeightsDecompressor(scale, original_weight.dtype)
            packed_q_weight = decompressor.pack_weight(q_weight)
            new_weight_node = constant_update_fn(model, observer_node, packed_q_weight, input_port_id=0)
            decompressor_name = f"NNCFDecompressor_{new_weight_node.name}"
            module_insertion_transformation_builder(
                decompressor,
                [PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name=new_weight_node.name)],
                decompressor_name,
            )(model)
        decomp_node = observer_node.args[0]
        observer_node.replace_all_uses_with(decomp_node)
        model.graph.erase_node(observer_node)
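With this convert() hook, the converted graph stores packed INT4/INT8 constants and dequantizes them at runtime through the inserted decompressor modules. As a rough illustration (not the NNCF implementation), the per-group symmetric INT4 scheme that do_integer_quantization applies boils down to one scale per group:

    import torch

    # Illustrative only: symmetric INT4 over groups of 2 along the last axis.
    weight = torch.tensor([[0.5, -1.0, 2.0, 4.0],
                           [1.0, 3.0, -2.0, -4.0]])
    groups = weight.reshape(2, 2, 2)                       # (rows, n_groups, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True) / 7.0  # qmax = 7 for signed 4-bit
    q = torch.clamp(torch.round(groups / scale), -8, 7).to(torch.int8)
    dequant = (q * scale).reshape(2, 4)                    # what the decompressor reproduces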

backends/openvino/quantizer/quantizer.py

Lines changed: 114 additions & 56 deletions
@@ -21,6 +21,8 @@
     HistogramObserver,
     PerChannelMinMaxObserver,
     UniformQuantizationObserverBase,
+    PerGroup,
+    MappingType,
 )
 from torchao.quantization.pt2e.quantizer import (
     EdgeOrNode,
@@ -30,6 +32,9 @@
     Quantizer,
     SharedQuantizationSpec,
 )
+from nncf.quantization.quantize_model import get_weight_compression_configuration
+from nncf.common.quantization.structs import QuantizerConfig, QuantizationScheme
+from executorch.backends.openvino.quantizer.observers.nncf_observers import PTPerBlockParamObserver, NNCFInt8observer

 QUANT_ANNOTATION_KEY = "quantization_annotation"

@@ -46,6 +51,10 @@ class QuantizationMode(Enum):
     INT8_SYM = "int8_sym"
     INT8_MIXED = "int8_mixed"
     INT8_TRANSFORMER = "int8_transformer"
+    INT8_SYM_WC = "int8_sym_wc"
+    INT8_ASYM_WC = "int8_asym_wc"
+    INT4_SYM_WC = "int4_sym"
+    INT4_ASYM_WC = "int4_asym"


 class OpenVINOQuantizer(Quantizer):
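Note that the INT4 enum values are already the bare NNCF mode strings, while the INT8 ones carry a "_wc" suffix; the constructor below normalizes both with replace("_wc", ""). A quick check of that mapping:

    from executorch.backends.openvino.quantizer.quantizer import QuantizationMode

    print(QuantizationMode.INT8_SYM_WC.value.replace("_wc", ""))  # "int8_sym"
    print(QuantizationMode.INT4_SYM_WC.value.replace("_wc", ""))  # "int4_sym" (no suffix to strip)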
@@ -66,8 +75,12 @@ def __init__(
         - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights.
         - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models
         Default value is INT8_SYM.
+        - INT4_SYM_WC: Symmetric INT4 weights-only compression
+        - INT4_ASYM_WC: Asymmetric INT4 weights-only compression
         :param kwargs: Arguments to pass to the NNCF MinMaxQuantization algorithm.
         """
+        self.mode = mode
+        self.wc_modes = [QuantizationMode.INT4_ASYM_WC, QuantizationMode.INT4_SYM_WC, QuantizationMode.INT8_ASYM_WC, QuantizationMode.INT8_SYM_WC]
         if mode == QuantizationMode.INT8_SYM:
             preset = quantization.structs.QuantizationPreset.PERFORMANCE
             model_type = None
@@ -77,11 +90,24 @@ def __init__(
         else:
             preset = None
             model_type = nncf.parameters.ModelType.TRANSFORMER
-        self._min_max_algo = (
-            nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization(
-                preset=preset, model_type=model_type, **kwargs
+        if self.mode not in self.wc_modes:
+            self._min_max_algo = (
+                nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization(
+                    preset=preset, model_type=model_type, **kwargs
+                )
             )
-        )
+            self._algo = self._min_max_algo
+        else:
+            weight_compression_configuration = get_weight_compression_configuration(
+                mode.value.replace("_wc", ""),  # mode value has to match an NNCF CompressWeightsMode
+                **kwargs
+            )
+            self._weight_compression_algo = nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression(
+                subset_size=None,
+                **weight_compression_configuration
+            )
+            self._algo = self._weight_compression_algo

     def set_ignored_scope(
         self,
@@ -102,7 +128,7 @@ def set_ignored_scope(
         :param validate: If set to True, then a RuntimeError will be raised if any ignored scope does not match
             in the model graph.
         """
-        self._min_max_algo.set_ignored_scope(
+        self._algo.set_ignored_scope(
             nncf.IgnoredScope(
                 names=names or [],
                 patterns=patterns or [],
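Because set_ignored_scope now routes through self._algo, the same call works whether the quantizer was built for min-max quantization or weight compression. A sketch with a hypothetical node name:

    quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT4_SYM_WC)
    quantizer.set_ignored_scope(names=["lm_head"])  # hypothetical node to keep uncompressed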
@@ -115,63 +141,80 @@ def set_ignored_scope(
     def get_nncf_quantization_setup(
         self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph
     ) -> quantization.quantizer_setup.SingleConfigQuantizerSetup:
-        self._min_max_algo._set_backend_entity(model)
-        return self._min_max_algo.find_quantization_setup(model, nncf_graph)
+        self._algo._set_backend_entity(model)
+        return self._algo.find_quantization_setup(model, nncf_graph)

     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model)
-        quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph)
-
         graph = model.graph
         node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = (
             defaultdict(QuantizationAnnotation)
         )
+        # Separate into annotation for quantize and compress
+        if self.mode in self.wc_modes:
+            self._algo.set_backend_entity(model)
+            nodes_to_compress = self._algo.get_nodes_to_compress(nncf_graph)
+            for node in nodes_to_compress:
+                quantization_insertion_point = quantization.quantizer_setup.WeightQuantizationInsertionPoint(target_node_name=node.node_name)
+                group_size = self._algo._group_size
+                num_bits = 4 if self.mode in [QuantizationMode.INT4_SYM_WC, QuantizationMode.INT4_ASYM_WC] else 8
+                qmode = QuantizationScheme.SYMMETRIC if self.mode in [QuantizationMode.INT4_SYM_WC, QuantizationMode.INT8_SYM_WC] else QuantizationScheme.ASYMMETRIC
+                nncf_qconfig = QuantizerConfig(num_bits=num_bits, mode=qmode)
+                qp = quantization.quantizer_setup.SingleConfigQuantizationPoint(qip=quantization_insertion_point, qconfig=nncf_qconfig, directly_quantized_operator_node_names=[node])
+                edge_or_node, annotation = self._get_edge_or_node_and_annotation(
+                    graph, nncf_graph, qp, node_vs_torch_annotation
+                )
+                qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_nncf_config(qp, group_size=group_size, weights_only=True)
+                self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
+        else:
+            quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph)

-        for qp in quantization_setup.quantization_points.values():
-            edge_or_node, annotation = self._get_edge_or_node_and_annotation(
-                graph, nncf_graph, qp, node_vs_torch_annotation
-            )
-            qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_qp(qp)
-            self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
+            for qp in quantization_setup.quantization_points.values():
+                edge_or_node, annotation = self._get_edge_or_node_and_annotation(
+                    graph, nncf_graph, qp, node_vs_torch_annotation
+                )
+                qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_nncf_config(qp)
+                self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)

-        for quantizer_ids in quantization_setup.unified_scale_groups.values():
+            for quantizer_ids in quantization_setup.unified_scale_groups.values():

-            root_quantizer_id = self._get_unified_scales_root_quantizer_id(
-                nncf_graph, quantizer_ids, quantization_setup
-            )
-            root_qp = quantization_setup.quantization_points[root_quantizer_id]
+                root_quantizer_id = self._get_unified_scales_root_quantizer_id(
+                    nncf_graph, quantizer_ids, quantization_setup
+                )
+                root_qp = quantization_setup.quantization_points[root_quantizer_id]

-            if any(
-                root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig
-                for q_id in quantizer_ids
-            ):
-                qps = [
-                    quantization_setup.quantization_points[q_id]
-                    for q_id in quantizer_ids
-                ]
-                msg = (
-                    "Different quantization configs are set to one unified scale group:"
-                    f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}"
-                )
-                raise nncf.InternalError(msg)
-
-            root_target_node = nncf_fx.node_utils.get_graph_node_by_name(
-                graph, root_qp.insertion_point.target_node_name
-            )
-            root_edge_or_node = self._get_edge_or_node(
-                root_target_node, root_qp, nncf_graph
-            )
+                if any(
+                    root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig
+                    for q_id in quantizer_ids
+                ):
+                    qps = [
+                        quantization_setup.quantization_points[q_id]
+                        for q_id in quantizer_ids
+                    ]
+                    msg = (
+                        "Different quantization configs are set to one unified scale group:"
+                        f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}"
+                    )
+                    raise nncf.InternalError(msg)
+
+                root_target_node = nncf_fx.node_utils.get_graph_node_by_name(
+                    graph, root_qp.insertion_point.target_node_name
+                )
+                root_edge_or_node = self._get_edge_or_node(
+                    root_target_node, root_qp, nncf_graph
+                )

-            for quantizer_id in quantizer_ids:
-                if quantizer_id == root_quantizer_id:
-                    continue
+                for quantizer_id in quantizer_ids:
+                    if quantizer_id == root_quantizer_id:
+                        continue

-                qspec = SharedQuantizationSpec(root_edge_or_node)
-                qp = quantization_setup.quantization_points[quantizer_id]
-                edge_or_node, annotation = self._get_edge_or_node_and_annotation(
-                    graph, nncf_graph, qp, node_vs_torch_annotation
-                )
-                self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
+                    qspec = SharedQuantizationSpec(root_edge_or_node)
+                    qp = quantization_setup.quantization_points[quantizer_id]
+                    edge_or_node, annotation = self._get_edge_or_node_and_annotation(
+                        graph, nncf_graph, qp, node_vs_torch_annotation
+                    )
+                    self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)

         for node, annotation in node_vs_torch_annotation.items():
             assert QUANT_ANNOTATION_KEY not in node.meta
@@ -295,8 +338,8 @@ def _fill_torch_ao_annotation(
             annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec

     @staticmethod
-    def _get_torch_ao_qspec_from_qp(
-        qp: quantization.quantizer_setup.QuantizationPointBase,
+    def _get_torch_ao_qspec_from_nncf_config(
+        qp: quantization.quantizer_setup.QuantizationPointBase, group_size=-1, weights_only=False
     ) -> QuantizationSpec:
         """
         Retrieves the quantization configuration from the given quantization point and
@@ -307,11 +350,10 @@ def _get_torch_ao_qspec_from_qp(
         """
         # Eps value is copied from nncf/torch/quantization/layers.py
        extra_args = {"eps": 1e-16}
-        qconfig = qp.qconfig
         is_weight = qp.is_weight_quantization_point()
+        qconfig = qp.qconfig

         observer: Type[UniformQuantizationObserverBase]
-
         if qconfig.per_channel:
             torch_qscheme = (
                 torch.per_channel_symmetric
@@ -325,11 +367,27 @@
                 else torch.per_tensor_affine
             )
         if is_weight:
-            observer = PerChannelMinMaxObserver
-            quant_min = -128
-            quant_max = 127
-            dtype = torch.int8
-            channel_axis = 0
+            mapping_type = MappingType.SYMMETRIC if qconfig.mode == QuantizationScheme.SYMMETRIC else MappingType.ASYMMETRIC
+            if qconfig.num_bits == 4:
+                extra_args["mapping_type"] = mapping_type
+                extra_args["target_dtype"] = torch.int8
+                extra_args["granularity"] = PerGroup(group_size=group_size)
+                observer = PTPerBlockParamObserver
+                quant_min = -8
+                quant_max = 7
+                dtype = torch.int8
+                channel_axis = 0
+            elif qconfig.num_bits == 8:
+                observer = NNCFInt8observer if weights_only else PerChannelMinMaxObserver
+                quant_min = -128
+                quant_max = 127
+                dtype = torch.int8
+                channel_axis = 0
+                torch_qscheme = (
+                    torch.per_channel_symmetric
+                    if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC
+                    else torch.per_channel_affine
+                )
         else:
             observer = (
                 HistogramObserver
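Putting the weight branch together: for an INT4 weight, the method ends up building a spec along these lines. This is a reconstruction from the hunk above, not code from the commit; group_size=128 is an arbitrary example value, and with_args is assumed to be inherited from the torch.ao observer base class.

    import torch
    from torch.ao.quantization.observer import MappingType, PerGroup
    from torchao.quantization.pt2e.quantizer import QuantizationSpec
    from executorch.backends.openvino.quantizer.observers.nncf_observers import PTPerBlockParamObserver

    qspec = QuantizationSpec(
        dtype=torch.int8,  # INT4 codes are carried in int8 storage
        quant_min=-8,
        quant_max=7,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=PTPerBlockParamObserver.with_args(
            eps=1e-16,
            mapping_type=MappingType.SYMMETRIC,
            target_dtype=torch.int8,
            granularity=PerGroup(group_size=128),  # example group size
        ),
    )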
