     HistogramObserver,
     PerChannelMinMaxObserver,
     UniformQuantizationObserverBase,
+    PerGroup,
+    MappingType,
 )
 from torchao.quantization.pt2e.quantizer import (
     EdgeOrNode,
     Quantizer,
     SharedQuantizationSpec,
 )
+from nncf.quantization.quantize_model import get_weight_compression_configuration
+from nncf.common.quantization.structs import QuantizerConfig, QuantizationScheme
+from executorch.backends.openvino.quantizer.observers.nncf_observers import PTPerBlockParamObserver, NNCFInt8observer
 
 QUANT_ANNOTATION_KEY = "quantization_annotation"
 
@@ -46,6 +51,10 @@ class QuantizationMode(Enum):
     INT8_SYM = "int8_sym"
     INT8_MIXED = "int8_mixed"
     INT8_TRANSFORMER = "int8_transformer"
+    INT8_SYM_WC = "int8_sym_wc"
+    INT8_ASYM_WC = "int8_asym_wc"
+    INT4_SYM_WC = "int4_sym"
+    INT4_ASYM_WC = "int4_asym"
 
 
 class OpenVINOQuantizer(Quantizer):
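The new *_WC members are weight-only compression modes; their string values are chosen so that, once an optional "_wc" suffix is stripped, they line up with NNCF's CompressWeightsMode values (see the constructor change below). A small sketch of that assumption:

    # Illustrative only; relies on nncf.CompressWeightsMode using the values
    # "int4_sym", "int4_asym", "int8_sym" and "int8_asym".
    from nncf import CompressWeightsMode

    assert QuantizationMode.INT4_SYM_WC.value == CompressWeightsMode.INT4_SYM.value
    assert QuantizationMode.INT8_SYM_WC.value.replace("_wc", "") == CompressWeightsMode.INT8_SYM.value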
@@ -66,8 +75,12 @@ def __init__(
             - INT8_MIXED: INT8 asymmetric quantization for activations, symmetric for weights.
             - INT8_TRANSFORMER: Optimized INT8 quantization for transformer-based models
             Default value is INT8_SYM.
+            - INT4_SYM_WC: Symmetric INT4 weight-only compression.
+            - INT4_ASYM_WC: Asymmetric INT4 weight-only compression.
         :param kwargs: Arguments to pass to the NNCF MinMaxQuantization algorithm.
         """
+        self.mode = mode
+        self.wc_modes = [QuantizationMode.INT4_ASYM_WC, QuantizationMode.INT4_SYM_WC, QuantizationMode.INT8_ASYM_WC, QuantizationMode.INT8_SYM_WC]
         if mode == QuantizationMode.INT8_SYM:
             preset = quantization.structs.QuantizationPreset.PERFORMANCE
             model_type = None
@@ -77,11 +90,24 @@ def __init__(
         else:
             preset = None
             model_type = nncf.parameters.ModelType.TRANSFORMER
-        self._min_max_algo = (
-            nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization(
-                preset=preset, model_type=model_type, **kwargs
+        if self.mode not in self.wc_modes:
+            self._min_max_algo = (
+                nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization(
+                    preset=preset, model_type=model_type, **kwargs
+                )
             )
-        )
+            self._algo = self._min_max_algo
+        else:
+            weight_compression_configuration = get_weight_compression_configuration(
+                mode.value.replace("_wc", ""),  # Mode value has to match NNCF CompressWeightsMode
+                **kwargs
+            )
+            self._weight_compression_algo = nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression(
+                subset_size=None,
+                **weight_compression_configuration
+            )
+            self._algo = self._weight_compression_algo
+
 
     def set_ignored_scope(
         self,
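Both branches end up behind the single self._algo handle that the rest of the quantizer uses. A minimal constructor sketch (attribute names come from the diff above; passing group_size through **kwargs assumes NNCF's get_weight_compression_configuration accepts it):

    ptq_quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT8_SYM)
    assert ptq_quantizer._algo is ptq_quantizer._min_max_algo

    wc_quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT4_SYM_WC, group_size=128)
    assert wc_quantizer._algo is wc_quantizer._weight_compression_algo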
@@ -102,7 +128,7 @@ def set_ignored_scope(
         :param validate: If set to True, then a RuntimeError will be raised if any ignored scope does not match
             in the model graph.
         """
-        self._min_max_algo.set_ignored_scope(
+        self._algo.set_ignored_scope(
             nncf.IgnoredScope(
                 names=names or [],
                 patterns=patterns or [],
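Because set_ignored_scope now routes through self._algo, the same call works for both the quantization and the weight-compression modes. A usage sketch with hypothetical scope entries:

    quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT4_ASYM_WC)
    # validate=False skips the strict match check described in the docstring above.
    quantizer.set_ignored_scope(names=["lm_head"], patterns=[".*embedding.*"], validate=False)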
@@ -115,63 +141,80 @@ def set_ignored_scope(
     def get_nncf_quantization_setup(
         self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph
     ) -> quantization.quantizer_setup.SingleConfigQuantizerSetup:
-        self._min_max_algo._set_backend_entity(model)
-        return self._min_max_algo.find_quantization_setup(model, nncf_graph)
+        self._algo._set_backend_entity(model)
+        return self._algo.find_quantization_setup(model, nncf_graph)
 
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         nncf_graph = nncf_fx.nncf_graph_builder.GraphConverter.create_nncf_graph(model)
-        quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph)
-
+
         graph = model.graph
         node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = (
             defaultdict(QuantizationAnnotation)
         )
+        # Separate annotation paths: weight-only compression vs. full quantization
+        if self.mode in self.wc_modes:
+            self._algo.set_backend_entity(model)
+            nodes_to_compress = self._algo.get_nodes_to_compress(nncf_graph)
+            for node in nodes_to_compress:
+                quantization_insertion_point = quantization.quantizer_setup.WeightQuantizationInsertionPoint(target_node_name=node.node_name)
+                group_size = self._algo._group_size
+                num_bits = 4 if self.mode in [QuantizationMode.INT4_SYM_WC, QuantizationMode.INT4_ASYM_WC] else 8
+                qmode = QuantizationScheme.SYMMETRIC if self.mode in [QuantizationMode.INT4_SYM_WC, QuantizationMode.INT8_SYM_WC] else QuantizationScheme.ASYMMETRIC
+                nncf_qconfig = QuantizerConfig(num_bits=num_bits, mode=qmode)
+                qp = quantization.quantizer_setup.SingleConfigQuantizationPoint(qip=quantization_insertion_point, qconfig=nncf_qconfig, directly_quantized_operator_node_names=[node.node_name])
+                edge_or_node, annotation = self._get_edge_or_node_and_annotation(
+                    graph, nncf_graph, qp, node_vs_torch_annotation
+                )
+                qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_nncf_config(qp, group_size=group_size, weights_only=True)
+                self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
+        else:
+            quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph)
 
-        for qp in quantization_setup.quantization_points.values():
-            edge_or_node, annotation = self._get_edge_or_node_and_annotation(
-                graph, nncf_graph, qp, node_vs_torch_annotation
-            )
-            qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_qp(qp)
-            self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
+            for qp in quantization_setup.quantization_points.values():
+                edge_or_node, annotation = self._get_edge_or_node_and_annotation(
+                    graph, nncf_graph, qp, node_vs_torch_annotation
+                )
+                qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_nncf_config(qp)
+                self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
 
-        for quantizer_ids in quantization_setup.unified_scale_groups.values():
+            for quantizer_ids in quantization_setup.unified_scale_groups.values():
 
-            root_quantizer_id = self._get_unified_scales_root_quantizer_id(
-                nncf_graph, quantizer_ids, quantization_setup
-            )
-            root_qp = quantization_setup.quantization_points[root_quantizer_id]
+                root_quantizer_id = self._get_unified_scales_root_quantizer_id(
+                    nncf_graph, quantizer_ids, quantization_setup
+                )
+                root_qp = quantization_setup.quantization_points[root_quantizer_id]
 
-            if any(
-                root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig
-                for q_id in quantizer_ids
-            ):
-                qps = [
-                    quantization_setup.quantization_points[q_id]
+                if any(
+                    root_qp.qconfig != quantization_setup.quantization_points[q_id].qconfig
                     for q_id in quantizer_ids
-                ]
-                msg = (
-                    "Different quantization configs are set to one unified scale group:"
-                    f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}"
+                ):
+                    qps = [
+                        quantization_setup.quantization_points[q_id]
+                        for q_id in quantizer_ids
+                    ]
+                    msg = (
+                        "Different quantization configs are set to one unified scale group:"
+                        f"{[(qp.insertion_point.__dict__, str(qp.qconfig)) for qp in qps]}"
+                    )
+                    raise nncf.InternalError(msg)
+
+                root_target_node = nncf_fx.node_utils.get_graph_node_by_name(
+                    graph, root_qp.insertion_point.target_node_name
+                )
+                root_edge_or_node = self._get_edge_or_node(
+                    root_target_node, root_qp, nncf_graph
                 )
-                raise nncf.InternalError(msg)
-
-            root_target_node = nncf_fx.node_utils.get_graph_node_by_name(
-                graph, root_qp.insertion_point.target_node_name
-            )
-            root_edge_or_node = self._get_edge_or_node(
-                root_target_node, root_qp, nncf_graph
-            )
 
-            for quantizer_id in quantizer_ids:
-                if quantizer_id == root_quantizer_id:
-                    continue
+                for quantizer_id in quantizer_ids:
+                    if quantizer_id == root_quantizer_id:
+                        continue
 
-                qspec = SharedQuantizationSpec(root_edge_or_node)
-                qp = quantization_setup.quantization_points[quantizer_id]
-                edge_or_node, annotation = self._get_edge_or_node_and_annotation(
-                    graph, nncf_graph, qp, node_vs_torch_annotation
-                )
-                self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
+                    qspec = SharedQuantizationSpec(root_edge_or_node)
+                    qp = quantization_setup.quantization_points[quantizer_id]
+                    edge_or_node, annotation = self._get_edge_or_node_and_annotation(
+                        graph, nncf_graph, qp, node_vs_torch_annotation
+                    )
+                    self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
 
         for node, annotation in node_vs_torch_annotation.items():
             assert QUANT_ANNOTATION_KEY not in node.meta
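For the weight-compression modes, annotate() attaches specs only to the weight inputs of the nodes NNCF selects for compression. An illustrative summary of what the INT4 branch further down configures per weight edge (values taken from the diff, not an exact reproduction of the spec object):

    # observer              -> PTPerBlockParamObserver, per-group via PerGroup(group_size) taken from the NNCF algorithm
    # quant_min / quant_max -> -8 / 7, carried as torch.int8 with channel_axis = 0
    # extra observer args   -> mapping_type (symmetric or asymmetric) and target_dtype=torch.int8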
@@ -295,8 +338,8 @@ def _fill_torch_ao_annotation(
             annotation_to_update.input_qspec_map[edge_or_node[0]] = qspec
 
     @staticmethod
-    def _get_torch_ao_qspec_from_qp(
-        qp: quantization.quantizer_setup.QuantizationPointBase,
+    def _get_torch_ao_qspec_from_nncf_config(
+        qp: quantization.quantizer_setup.QuantizationPointBase, group_size: int = -1, weights_only: bool = False
     ) -> QuantizationSpec:
         """
         Retrieves the quantization configuration from the given quantization point and
@@ -307,11 +350,10 @@ def _get_torch_ao_qspec_from_qp(
         """
         # Eps value is copied from nncf/torch/quantization/layers.py
         extra_args = {"eps": 1e-16}
-        qconfig = qp.qconfig
         is_weight = qp.is_weight_quantization_point()
+        qconfig = qp.qconfig
 
         observer: Type[UniformQuantizationObserverBase]
-
         if qconfig.per_channel:
             torch_qscheme = (
                 torch.per_channel_symmetric
@@ -325,11 +367,27 @@ def _get_torch_ao_qspec_from_qp(
                 else torch.per_tensor_affine
             )
         if is_weight:
-            observer = PerChannelMinMaxObserver
-            quant_min = -128
-            quant_max = 127
-            dtype = torch.int8
-            channel_axis = 0
+            mapping_type = MappingType.SYMMETRIC if qconfig.mode == QuantizationScheme.SYMMETRIC else MappingType.ASYMMETRIC
+            if qconfig.num_bits == 4:
+                extra_args["mapping_type"] = mapping_type
+                extra_args["target_dtype"] = torch.int8
+                extra_args["granularity"] = PerGroup(group_size=group_size)
+                observer = PTPerBlockParamObserver
+                quant_min = -8
+                quant_max = 7
+                dtype = torch.int8
+                channel_axis = 0
+            elif qconfig.num_bits == 8:
+                observer = NNCFInt8observer if weights_only else PerChannelMinMaxObserver
+                quant_min = -128
+                quant_max = 127
+                dtype = torch.int8
+                channel_axis = 0
+                torch_qscheme = (
+                    torch.per_channel_symmetric
+                    if qconfig.mode is quantization.structs.QuantizationScheme.SYMMETRIC
+                    else torch.per_channel_affine
+                )
         else:
             observer = (
                 HistogramObserver
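Taken together, the weight-only path plugs into the usual PT2E flow. A minimal sketch, assuming OpenVINOQuantizer and QuantizationMode are re-exported from the backend's quantizer package, that group_size is accepted through **kwargs, and that prepare_pt2e/convert_pt2e come from the torchao mirror of the PT2E API (the exact import path may differ by torch/torchao version):

    import torch
    from executorch.backends.openvino.quantizer import OpenVINOQuantizer, QuantizationMode
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    # Small self-contained float model; any linear-heavy module would do.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 64),)

    exported = torch.export.export(model, example_inputs).module()

    quantizer = OpenVINOQuantizer(mode=QuantizationMode.INT4_SYM_WC, group_size=32)
    prepared = prepare_pt2e(exported, quantizer)  # annotate() runs here; only weight edges get observers
    prepared(*example_inputs)                     # one pass lets the weight observers see the (static) weights
    compressed = convert_pt2e(prepared)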