-
Notifications
You must be signed in to change notification settings - Fork 0
[OVQuantizer] Apply Fixes and Integrate into the Llama Example Workflow #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
30a1a25
4cc7694
5da40a5
9e65a7e
53e0f4c
bf95930
2d4bec7
0a2e361
c8ea777
a6b605f
9614fc4
45007cf
9d49414
d6727cf
4a0a781
f6a1ee3
d285fcc
4e66df1
e850e41
204043f
ae6b089
a6f036c
2de5693
0e10f28
05f5a92
fbe0e21
6bff1cd
d744ae9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,3 @@ | ||
| from .quantizer import OpenVINOQuantizer, quantize_model | ||
| from .quantizer import OpenVINOQuantizer, quantize_model, QuantizationMode | ||
|
|
||
| __all__ = ["OpenVINOQuantizer", "quantize_model"] | ||
| __all__ = ["OpenVINOQuantizer", "quantize_model", "QuantizationMode"] |
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,199 @@ | ||||||||||||||||
| # Copyright (c) Intel Corporation | ||||||||||||||||
| # | ||||||||||||||||
| # Licensed under the BSD License (the "License"); you may not use this file | ||||||||||||||||
| # except in compliance with the License. See the license file found in the | ||||||||||||||||
| # LICENSE file in the root directory of this source tree. | ||||||||||||||||
|
|
||||||||||||||||
| # mypy: disable-error-code=import-not-found | ||||||||||||||||
|
|
||||||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||||||
| from typing import Optional, Tuple | ||||||||||||||||
|
|
||||||||||||||||
| import torch | ||||||||||||||||
|
|
||||||||||||||||
| from nncf.experimental.torch.fx.node_utils import ( # type: ignore[import-untyped] | ||||||||||||||||
| get_tensor_constant_from_node, | ||||||||||||||||
| ) | ||||||||||||||||
| from nncf.experimental.torch.fx.transformations import ( # type: ignore[import-untyped] | ||||||||||||||||
| constant_update_fn, | ||||||||||||||||
| module_insertion_transformation_builder, | ||||||||||||||||
| ) | ||||||||||||||||
| from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped] | ||||||||||||||||
| WeightCompressionParameters, | ||||||||||||||||
| ) | ||||||||||||||||
| from nncf.quantization.algorithms.weight_compression.weight_lowering import ( # type: ignore[import-untyped] | ||||||||||||||||
| do_integer_quantization, | ||||||||||||||||
| ) | ||||||||||||||||
| from nncf.tensor.tensor import Tensor as NNCFTensor # type: ignore[import-untyped] | ||||||||||||||||
| from nncf.torch.graph.transformations.commands import ( # type: ignore[import-untyped] | ||||||||||||||||
| PTTargetPoint, | ||||||||||||||||
| TargetType, | ||||||||||||||||
| ) | ||||||||||||||||
| from nncf.torch.quantization.layers import ( # type: ignore[import-untyped] | ||||||||||||||||
| BaseWeightsDecompressor, | ||||||||||||||||
| INT4AsymmetricWeightsDecompressor, | ||||||||||||||||
| INT4SymmetricWeightsDecompressor, | ||||||||||||||||
| INT8AsymmetricWeightsDecompressor, | ||||||||||||||||
| INT8SymmetricWeightsDecompressor, | ||||||||||||||||
| ) | ||||||||||||||||
| from torchao.quantization.pt2e import ObserverBase | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class WeightObserverBase(ObserverBase, ABC): | ||||||||||||||||
| """ | ||||||||||||||||
| Base implementation of an NNCF observer that defines the rules for compressing layer weights into the OpenVINO representation. | ||||||||||||||||
| """ | ||||||||||||||||
|
|
||||||||||||||||
| def __init__( | ||||||||||||||||
| self, | ||||||||||||||||
| wc_param: WeightCompressionParameters, | ||||||||||||||||
| dtype: torch.dtype, | ||||||||||||||||
| **kwargs, | ||||||||||||||||
| ) -> None: | ||||||||||||||||
| """ | ||||||||||||||||
| :param wc_param: Weight compression parameters container. | ||||||||||||||||
| :param dtype: target dtype for the quantization. | ||||||||||||||||
| """ | ||||||||||||||||
| super().__init__(dtype=dtype, is_dynamic=False) | ||||||||||||||||
| self.wc_param = wc_param | ||||||||||||||||
|
|
||||||||||||||||
| def calculate_qparams( # type: ignore[override] | ||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this method be private? Or is it used somewhere in the base class?
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||||||||||||||||
| self, | ||||||||||||||||
| weight: torch.Tensor, | ||||||||||||||||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | ||||||||||||||||
| """ | ||||||||||||||||
| Calculates quantization parameters: quantized weight, quantization scale and quantization zero point. | ||||||||||||||||
|
|
||||||||||||||||
| :param weight: FP weight to be used for calculating qparams. | ||||||||||||||||
| :return: A tuple containing the quantized weight, quantization scale and quantization zero point. | ||||||||||||||||
| """ | ||||||||||||||||
| wc_param = self.get_wc_param() | ||||||||||||||||
|
||||||||||||||||
| wc_param = self.get_wc_param() | |
| wc_param = self._wc_param |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the observer node is a ObserverBase subclass? If so, I would put ObserverBase as the typehint for the observer_node
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Observer node is a torch.fx.Node with observer base class as its target
anzr299 marked this conversation as resolved.
Show resolved
Hide resolved
anzr299 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def get_wc_param(self) -> WeightCompressionParameters: | |
| """ | |
| Returns a respective NNCF Weight Compression Config. | |
| :return: Weight compression config with the compression information such as qmode, group_size etc. | |
| """ | |
| return self.wc_param |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if we invert the condition here? IMHO is None is clearer than is not None :)
| if zero_point is not None: | |
| if zero_point is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| else: | |
| return INT4SymmetricWeightsDecompressor( | |
| scale, q_weight.shape, original_weight.shape, original_weight.dtype | |
| ) | |
| return INT4SymmetricWeightsDecompressor( | |
| scale, q_weight.shape, original_weight.shape, original_weight.dtype | |
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same comment as above regarding the condition
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| else: | |
| return INT8SymmetricWeightsDecompressor(scale, original_weight.dtype) | |
| return INT8SymmetricWeightsDecompressor(scale, original_weight.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done