Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
30a1a25
openvino quantizer refactored
anzr299 Aug 26, 2025
4cc7694
fixes
anzr299 Aug 26, 2025
5da40a5
support all_layers, backup mode in OVQuantizer
anzr299 Aug 27, 2025
9e65a7e
clean up and use new nncf method for obtaining compression parameters
anzr299 Aug 27, 2025
53e0f4c
review changes & update method names according to wc algo
anzr299 Sep 1, 2025
bf95930
review changes
anzr299 Sep 1, 2025
2d4bec7
review changes
anzr299 Sep 1, 2025
0a2e361
Update export_llama_lib.py
anzr299 Sep 3, 2025
c8ea777
use new transformations
anzr299 Sep 6, 2025
a6b605f
add comment for manual MP allocation
anzr299 Sep 6, 2025
9614fc4
remove nncf_compression from export llama lib
anzr299 Sep 6, 2025
45007cf
change pt2e quantize flag to use openvino_4wo instead of openvino_8da…
anzr299 Sep 6, 2025
9d49414
follow up to last commit
anzr299 Sep 6, 2025
d6727cf
update quantizer lib with openvino_4wo
anzr299 Sep 6, 2025
4a0a781
split qspec function into 2 parts; 1 for WC and other for PTQ qspecs
anzr299 Sep 6, 2025
f6a1ee3
micro fix
anzr299 Sep 8, 2025
d285fcc
update mixed precision layers for higher accuracy. Change INT4 mode t…
anzr299 Sep 8, 2025
4e66df1
Apply suggestions from code review
anzr299 Sep 8, 2025
e850e41
Review changes
anzr299 Sep 8, 2025
204043f
review changes in quantizer
anzr299 Sep 8, 2025
ae6b089
revert extra args changes
anzr299 Sep 8, 2025
a6f036c
Merge branch 'openvino_llama_support' of https://github.com/anzr299/e…
anzr299 Sep 9, 2025
2de5693
precommit fixes
anzr299 Sep 9, 2025
0e10f28
revert _calculate_qparams back to calculate_qparams
anzr299 Sep 9, 2025
05f5a92
remove manual ignored nodes
anzr299 Sep 10, 2025
fbe0e21
add ratio to quantizer initialization
anzr299 Sep 10, 2025
6bff1cd
Update export_llama_lib.py
anzr299 Sep 11, 2025
d744ae9
Update quantizer_lib.py
anzr299 Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/openvino/quantizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .quantizer import OpenVINOQuantizer, quantize_model
from .quantizer import OpenVINOQuantizer, quantize_model, QuantizationMode

__all__ = ["OpenVINOQuantizer", "quantize_model"]
__all__ = ["OpenVINOQuantizer", "quantize_model", "QuantizationMode"]
198 changes: 198 additions & 0 deletions backends/openvino/quantizer/observers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (c) Intel Corporation
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file found in the
# LICENSE file in the root directory of this source tree.

# mypy: disable-error-code=import-not-found

from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch

from nncf.experimental.torch.fx.node_utils import ( # type: ignore[import-untyped]
get_tensor_constant_from_node,
)
from nncf.experimental.torch.fx.transformations import ( # type: ignore[import-untyped]
constant_update,
module_insertion,
node_removal,
)
from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped]
WeightCompressionParameters,
)
from nncf.quantization.algorithms.weight_compression.weight_lowering import ( # type: ignore[import-untyped]
do_integer_quantization,
)
from nncf.tensor.tensor import Tensor as NNCFTensor # type: ignore[import-untyped]
from nncf.torch.graph.transformations.commands import ( # type: ignore[import-untyped]
PTTargetPoint,
TargetType,
)
from nncf.torch.quantization.layers import ( # type: ignore[import-untyped]
BaseWeightsDecompressor,
INT4AsymmetricWeightsDecompressor,
INT4SymmetricWeightsDecompressor,
INT8AsymmetricWeightsDecompressor,
INT8SymmetricWeightsDecompressor,
)
from torchao.quantization.pt2e import ObserverBase


class WeightObserverBase(ObserverBase, ABC):
    """
    Base observer that compresses layer weights into the OpenVINO
    representation.

    Observation itself is a no-op; at convert() time the observed FP weight
    constant is replaced by its quantized counterpart and an NNCF
    decompression module is inserted after it.
    """

    def __init__(
        self,
        wc_param: WeightCompressionParameters,
        dtype: torch.dtype,
        **kwargs,
    ) -> None:
        """
        :param wc_param: Weight compression parameters container.
        :param dtype: Target dtype for the quantization.
        """
        super().__init__(dtype=dtype, is_dynamic=False)
        # Private: consumers go through calculate_qparams(); no public
        # getter is exposed (review-accepted simplification).
        self._wc_param = wc_param

    def calculate_qparams(  # type: ignore[override]
        self,
        weight: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Calculates quantization parameters: quantized weight, quantization
        scale and quantization zero point.

        :param weight: FP weight to be used for calculating qparams.
        :return: Tuple of (quantized weight, scale, zero point); the zero
            point is None for symmetric quantization modes.
        """
        wc_config = self._wc_param.compression_config
        reduction_axes = self._wc_param.reduction_axes
        q_weight, scale, zp = do_integer_quantization(
            NNCFTensor(weight), wc_config, reduction_axes=reduction_axes
        )
        return q_weight.data, scale.data, zp.data if zp is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass-through: weights are compressed statically at convert() time,
        # so no statistics need to be collected during calibration.
        return x

    def convert(
        self, model: torch.fx.GraphModule, observer_node: torch.fx.Node
    ) -> None:
        """
        Replaces the given observer node from the given model with a quantized
        weight and an OpenVINO-specific decompression module.

        :param model: A `torch.fx.GraphModule` representing the statically
            traced model with observer nodes attached and calibrated.
        :param observer_node: The `torch.fx.Node` corresponding to the
            observer module for the weight being transformed into a
            compressed representation.
        """
        weight_node = observer_node.args[0]
        original_weight = get_tensor_constant_from_node(weight_node, model)
        q_weight, scale, zero_point = self.calculate_qparams(original_weight)

        decompressor = self._create_decompressor(
            scale, zero_point, q_weight, original_weight
        )
        packed_q_weight = decompressor.pack_weight(q_weight)

        # Weight port id is 0 since the observer is inserted for a single
        # weight only.
        constant_update(model, observer_node, packed_q_weight, input_port_id=0)

        compressed_weight_name = observer_node.all_input_nodes[0].name
        decompressor_suffix = "_".join(
            compressed_weight_name.replace(".", "_").split("_")[:-2]
        )
        decompressor_name = f"{decompressor.quantization_mode}_weights_decompressor_{decompressor_suffix}"

        module_insertion(
            model,
            decompressor,
            [
                PTTargetPoint(
                    TargetType.OPERATOR_POST_HOOK,
                    target_node_name=compressed_weight_name,
                )
            ],
            decompressor_name,
        )
        node_removal(model, observer_node, 0)

    @abstractmethod
    def _create_decompressor(
        self,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        q_weight: torch.Tensor,
        original_weight: torch.Tensor,
    ) -> BaseWeightsDecompressor:
        """
        Returns a respective NNCF decompressor for the quantization type.

        :param scale: Calculated scale quantization parameter.
        :param zero_point: Calculated zero point, or None for symmetric mode.
        :param q_weight: Calculated quantized weight.
        :param original_weight: FP weight.
        :return: NNCF decompressor that builds the decompression subgraph
            supported by OpenVINO.
        """

class INT4WeightObserver(WeightObserverBase):
    """
    OpenVINO INT4 Weight Compression observer.
    """

    def _create_decompressor(
        self,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        q_weight: torch.Tensor,
        original_weight: torch.Tensor,
    ) -> BaseWeightsDecompressor:
        # A missing zero point signals symmetric quantization; guard clause
        # avoids the inverted `is not None`/else branching.
        if zero_point is None:
            return INT4SymmetricWeightsDecompressor(
                scale, q_weight.shape, original_weight.shape, original_weight.dtype
            )
        return INT4AsymmetricWeightsDecompressor(
            scale,
            zero_point,
            q_weight.shape,
            original_weight.shape,
            original_weight.dtype,
        )



class INT8WeightObserver(WeightObserverBase):
    """
    OpenVINO INT8 Weight Compression per-channel observer.
    """

    def _create_decompressor(
        self,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        q_weight: torch.Tensor,
        original_weight: torch.Tensor,
    ) -> BaseWeightsDecompressor:
        # A missing zero point signals symmetric quantization; guard clause
        # avoids the inverted `is not None`/else branching.
        if zero_point is None:
            return INT8SymmetricWeightsDecompressor(scale, original_weight.dtype)
        return INT8AsymmetricWeightsDecompressor(
            scale, zero_point, original_weight.dtype
        )


176 changes: 0 additions & 176 deletions backends/openvino/quantizer/observers/nncf_observers.py

This file was deleted.

Loading