Commit 648527c

No public description

PiperOrigin-RevId: 647808740

Google AI Edge authored and junjiang-lab committed Jun 28, 2024
1 parent 133b893 commit 648527c
Showing 15 changed files with 337 additions and 137 deletions.
@@ -215,7 +215,6 @@ def init_qsvs(
   """

   # Set min/max to 0/6 to help stablize the calibration process
-  # http://google3/third_party/py/mojax/operation_manager/algorithms/ptq_static_range.py?l=297
   init_min_val, init_max_val = 0.0, 6.0
   op_qsvs = {}
@@ -253,7 +252,7 @@ def min_max_calibrate(
     inputs_to_ignore: Optional[list[int]] = None,
     outputs_to_ignore: Optional[list[int]] = None,
 ) -> dict[str, qtyping.QSV]:
-  """Collect quantization statistics variable (QSV, e.g., scale/zero_point) for the op.
+  """Collect quantization statistics variable (QSV, e.g., min/max) for the op.

   Args:
     tfl_op: the tfl operation.
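The docstring fix above reflects what calibration actually records: per-tensor min/max statistics, from which scale/zero_point are derived later at quantization time. A minimal sketch of the idea, written independently of the library (the function name and QSV layout here are illustrative, not the internal API):

  import numpy as np

  def collect_min_max_qsv(tensor_content: np.ndarray) -> dict[str, np.ndarray]:
    # Track the observed value range of one tensor; quantization parameters
    # (scale/zero_point) are computed from these statistics afterwards.
    return {
        'min': np.min(tensor_content, keepdims=True),
        'max': np.max(tensor_content, keepdims=True),
    }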
13 changes: 7 additions & 6 deletions ai_edge_quantizer/calibrator.py
@@ -1,7 +1,8 @@
 """Quantization Calibration."""

 from collections.abc import Iterable
-from typing import Any, Optional
+import copy
+from typing import Any

 from absl import logging
@@ -19,11 +20,11 @@ class Calibrator:

   def __init__(
       self,
-      float_tflite_path: str,
+      float_tflite: str | bytearray,
   ):
-    self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite_path)
+    self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
     self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
-        float_tflite_path
+        float_tflite
     )
     # Tensor name to tensor content.
     self._tensor_content_map: dict[str, Any] = {}
@@ -35,7 +36,7 @@ def calibrate(
       self,
       calibration_dataset: Iterable[_SignatureInput],
       model_recipe_manager: recipe_manager.RecipeManager,
-      signature_key: Optional[str] = None,
+      signature_key: str | None = None,
   ) -> None:
     """Calibrates the model using the given dataset for a model signature.
@@ -134,7 +135,7 @@ def load_model_qsvs(self, model_qsvs: dict[str, qtyping.QSV]) -> None:
     Args:
       model_qsvs: A dictionary of tensor name to QSV.
     """
-    self._model_qsvs = model_qsvs
+    self._model_qsvs = copy.deepcopy(model_qsvs)

   def _update_qsvs(self, op_qsvs: dict[str, qtyping.QSV]):
     """Update the model qsvs with the new values.
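Two behavioral changes in calibrator.py: the constructor now accepts a flatbuffer as bytes in addition to a file path, and load_model_qsvs stores a deep copy so later calibration updates cannot mutate the caller's dictionary. A hedged usage sketch; the path, dataset, and recipe manager instance are placeholders:

  # Assumes: model_path points to a float TFLite model, calibration_dataset is
  # an iterable of {input_name: np.ndarray} dicts, and recipe_mgr is a
  # recipe_manager.RecipeManager with a loaded recipe.
  with open(model_path, 'rb') as f:
    model_bytes = bytearray(f.read())

  calib = calibrator.Calibrator(model_bytes)  # a path string also still works
  calib.load_model_qsvs(previous_qsvs)        # stored as a deep copy
  calib.calibrate(calibration_dataset, recipe_mgr)
  qsvs = calib.get_model_qsvs()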
67 changes: 51 additions & 16 deletions ai_edge_quantizer/examples/mnist_toy_model.py
@@ -5,6 +5,7 @@
 """

 import os
+import random

 from absl import app
 from absl import flags
@@ -37,6 +38,26 @@
     '/tmp/',
     'Path to save the quantized model and recipe.',
 )
+_QUANTIZATION_MODE = flags.DEFINE_enum(
+    'quantization_mode',
+    'weight_only',
+    ['weight_only', 'drq', 'a8w8', 'a16w8'],
+    'How to quantize the model (e.g., weight_only, drq, a8w8, a16w8).',
+)
+
+
+def _get_calibration_data(
+    num_samples: int = 256,
+) -> list[dict[str, np.ndarray]]:
+  (x_train, _), _ = tf.keras.datasets.mnist.load_data()
+  x_train = x_train / 255.0  # Normalize pixel values to 0-1.
+  x_train = x_train.astype(np.float32)
+  x_train = x_train.reshape([-1, 28, 28, 1])
+  calibration_data = []
+  for _ in range(num_samples):
+    sample = random.choice(x_train)
+    calibration_data.append({'conv2d_input': sample.reshape([-1, 28, 28, 1])})
+  return calibration_data


 def read_img(img_path: str):
@@ -63,31 +84,45 @@ def read_img(img_path: str):

 def quantize(
     float_model_path: str,
-    execution_mode: _OpExecutionMode = _OpExecutionMode.WEIGHT_ONLY,
+    quantization_mode: str,
 ) -> quantizer.QuantizationResult:
   """Quantize the float model.

   Args:
     float_model_path: Path to the float model.
-    execution_mode: Execution mode for the quantized model.
+    quantization_mode: How to quantize the model (e.g., weight_only, drq).

   Returns:
     QuantResult: quantization result
   """
+  if quantization_mode == 'weight_only':
+    recipe_path = test_utils.get_path_to_datafile(
+        '../tests/recipes/conv_fc_mnist_weight_only_recipe.json'
+    )
+  elif quantization_mode == 'drq':
+    recipe_path = test_utils.get_path_to_datafile(
+        '../tests/recipes/conv_fc_mnist_drq_recipe.json'
+    )
+  elif quantization_mode == 'a8w8':
+    recipe_path = test_utils.get_path_to_datafile(
+        '../tests/recipes/conv_fc_mnist_a8w8_recipe.json'
+    )
+  elif quantization_mode == 'a16w8':
+    recipe_path = test_utils.get_path_to_datafile(
+        '../tests/recipes/conv_fc_mnist_a16w8_recipe.json'
+    )
+  else:
+    raise ValueError(
+        'Invalid quantization mode. Only weight_only, drq, a8w8, a16w8 are'
+        ' supported.'
+    )
+
   qt = quantizer.Quantizer(float_model_path)
-  qt.update_quantization_recipe(
-      regex='.*',
-      operation_name=_OpName.FULLY_CONNECTED,
-      op_config=_OpQuantConfig(
-          weight_tensor_config=_TensorQuantConfig(
-              num_bits=8,
-              symmetric=False,
-              channel_wise=True,
-          ),
-          execution_mode=execution_mode,
-      ),
-  )
-  return qt.quantize()
+  qt.load_quantization_recipe(recipe_path)
+  calibration_result = None
+  if qt.need_calibration:
+    calibration_result = qt.calibrate(_get_calibration_data())
+  return qt.quantize(calibration_result)


 def inference(quantized_tflite: bytes, image_path: str) -> np.ndarray:

@@ -116,7 +151,7 @@ def main(_) -> None:
   )
   if not os.path.exists(_IMG_PATH.value):
     raise ValueError('Image file does not exist. Please check the image path.')
-  quant_result = quantize(_FLOAT_MODEL_PATH.value, _OpExecutionMode.WEIGHT_ONLY)
+  quant_result = quantize(_FLOAT_MODEL_PATH.value, _QUANTIZATION_MODE.value)
   category_probabilities = inference(
       quant_result.quantized_model, _IMG_PATH.value
   )
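The example now selects one of four bundled recipes by name and calibrates automatically when the recipe requires it (the static-range a8w8/a16w8 modes). A sketch of calling the new quantize() directly, assuming an ai_edge_quantizer checkout so the bundled recipe paths resolve; the model and output paths are placeholders:

  # 'a8w8' triggers calibration internally via _get_calibration_data().
  result = quantize('/tmp/conv_fc_mnist.tflite', 'a8w8')
  with open('/tmp/conv_fc_mnist_a8w8.tflite', 'wb') as f:
    f.write(result.quantized_model)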
1 change: 0 additions & 1 deletion ai_edge_quantizer/model_modifier_test.py
@@ -53,7 +53,6 @@ def test_modify_model(self):
             },
             'execution_mode': qtyping.OpExecutionMode.WEIGHT_ONLY,
         },
-        'override_algorithm': True,
     },
 ]
 recipe_manager_instance.load_quantization_recipe(global_recipe)
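With override_algorithm removed here and in the tests below, a recipe entry reduces to a scope regex, an operation, an algorithm, and an op config. A representative entry reconstructed from the visible test fixtures; key names not shown in this diff, such as 'operation' and 'algorithm_key', are assumptions:

  recipe_entry = {
      'regex': '.*',
      'operation': 'FULLY_CONNECTED',               # assumed key name
      'algorithm_key': 'min_max_uniform_quantize',  # assumed key name/value
      'op_config': {
          'weight_tensor_config': {
              'num_bits': 8,
              'symmetric': False,
              'channel_wise': True,
          },
          'execution_mode': 'WEIGHT_ONLY',
      },
  }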
12 changes: 6 additions & 6 deletions ai_edge_quantizer/params_generator.py
@@ -215,13 +215,13 @@ def _check_buffer_sharing(self) -> None:
     """
     for tensors in self.buffer_to_tensors.values():
       first_tensor = tensors[0]
-      first_tensor_params = self.model_quant_results[
-          tfl_flatbuffer_utils.get_tensor_name(first_tensor)
-      ]
+      first_tensor_params = self.model_quant_results.get(
+          tfl_flatbuffer_utils.get_tensor_name(first_tensor), None
+      )
       for tensor in tensors[1:]:
-        tensor_params = self.model_quant_results[
-            tfl_flatbuffer_utils.get_tensor_name(tensor)
-        ]
+        tensor_params = self.model_quant_results.get(
+            tfl_flatbuffer_utils.get_tensor_name(tensor), None
+        )
         error_msg = (
             f'The tensors {first_tensor.name} and {tensor.name} do not have the'
             ' same quantization parameters even though they share the same'
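The switch from subscripting to dict.get means a tensor that never received quantization parameters now surfaces as None instead of raising KeyError, so the buffer-sharing check can compare such tensors like any other pair. The difference in isolation (standalone Python, not the library code):

  model_quant_results = {'conv/weights': 'params_a'}

  # Old behavior: subscripting a missing tensor name raises.
  try:
    _ = model_quant_results['missing_tensor']
  except KeyError:
    pass

  # New behavior: .get() yields None, a comparable "unquantized" state.
  assert model_quant_results.get('missing_tensor', None) is None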
5 changes: 0 additions & 5 deletions ai_edge_quantizer/params_generator_test.py
@@ -124,7 +124,6 @@ def test_generate_config_global(self):
             },
             'execution_mode': _OpExecutionMode.WEIGHT_ONLY,
         },
-        'override_algorithm': True,
     },
 ]
 self._recipe_manager.load_quantization_recipe(global_recipe)
@@ -262,7 +261,6 @@ def test_generate_config_selective(self):
             },
             'execution_mode': _OpExecutionMode.DRQ,
         },
-        'override_algorithm': True,
     },
     {
         'regex': '.*/dense_1/.*',
@@ -277,7 +275,6 @@
             },
             'execution_mode': _OpExecutionMode.WEIGHT_ONLY,
         },
-        'override_algorithm': True,
     },
 ]
 self._recipe_manager.load_quantization_recipe(selective_quantization_recipe)
@@ -328,7 +325,6 @@ def test_generate_config_edge_cases(self):
             },
             'execution_mode': _OpExecutionMode.DRQ,
         },
-        'override_algorithm': True,
     },
     # Scope that does not exist in the model.
     {
@@ -343,7 +339,6 @@
             },
             'execution_mode': _OpExecutionMode.WEIGHT_ONLY,
         },
-        'override_algorithm': True,
     },
 ]
 self._recipe_manager.load_quantization_recipe(selective_quantization_recipe)
61 changes: 51 additions & 10 deletions ai_edge_quantizer/quantizer.py
@@ -6,6 +6,7 @@
 import os
 from typing import Any, Optional, Union
 from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import calibrator
 from ai_edge_quantizer import model_modifier
 from ai_edge_quantizer import model_validator
 from ai_edge_quantizer import params_generator
@@ -24,6 +25,7 @@
 _TensorQuantizationConfig = qtyping.TensorQuantizationConfig
 _TensorTransformationParams = dict[str, qtyping.TensorTransformationParams]
 _SignatureInput = dict[str, Any]  # input_argument_name -> tensor_value.
+_CalibrationResult = dict[str, qtyping.QSV]


 @dataclasses.dataclass(frozen=True)
@@ -126,7 +128,6 @@ def update_quantization_recipe(
       operation_name: _TFLOpName,
       op_config: Optional[_OpQuantizationConfig] = None,
       algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
-      override_algorithm: bool = True,
   ):
     """Adds a quantization configuration to the recipe.
@@ -146,25 +147,62 @@
         default configuration. None or empty dict means the default
         configuration will be used.
       algorithm_key: Algorithm key to be applied.
-      override_algorithm: Flag to check if this rule overrides the previously
-        matched rule with different algorithm key.
     """
     self._recipe_manager.add_quantization_config(
-        regex, operation_name, op_config, algorithm_key, override_algorithm
+        regex, operation_name, op_config, algorithm_key
     )

-  def quantize(self) -> QuantizationResult:
+  @property
+  def need_calibration(self) -> bool:
+    """Checks if the current recipe needs calibration."""
+    return self._recipe_manager.need_calibration()
+
+  def calibrate(
+      self,
+      calibration_data: Iterable[_SignatureInput],
+      signature_key: Optional[str] = None,
+      previous_calibration_result: Optional[_CalibrationResult] = None,
+  ) -> _CalibrationResult:
+    """Calibrates the float model (required by static range quantization).
+
+    Args:
+      calibration_data: Calibration data for a model signature.
+      signature_key: The signature key to be used for invoking the models. If
+        the model doesn't have a signature key, this can be set to None.
+      previous_calibration_result: Previous calibration result to be loaded.
+        The calibration process will be resumed from the previous result.
+
+    Returns:
+      Calibration result ({tensor_name: tensor QSVs (e.g., min/max)}).
+    """
+    if not self.need_calibration:
+      return {}
+
+    calib = calibrator.Calibrator(self.float_model)
+    if previous_calibration_result is not None:
+      calib.load_model_qsvs(previous_calibration_result)
+    calib.calibrate(calibration_data, self._recipe_manager, signature_key)
+    return calib.get_model_qsvs()
+
+  def quantize(
+      self, calibration_result: Optional[_CalibrationResult] = None
+  ) -> QuantizationResult:
     """Quantizes the float model.

+    Args:
+      calibration_result: Calibration result to be used for quantization (if
+        needed, check with self.need_calibration).
+
     Returns:
       Quantization result.

     Raises:
-      RuntimeError: If no quantization recipe is loaded.
+      RuntimeError: If quantization recipe is empty.
     """

     if not self.get_quantization_recipe():
       raise RuntimeError('Can not quantize without a quantization recipe.')
-    quant_params = self._get_quantization_params()
+    quant_params = self._get_quantization_params(calibration_result)
     quantized_model = self._get_quantized_model(quant_params)
     self._result = QuantizationResult(
         self.get_quantization_recipe(), quantized_model
@@ -216,7 +254,6 @@ def compare(
         self.float_model,
         self._result.quantized_model,
         signature_test_data,
-        quantize_target_input=False,  # will be removed later.
         compare_fn=validation_utils.get_validation_func(error_metrics),
         signature_key=signature_key,
     )
@@ -245,18 +282,22 @@ def save_comparison_result(
       output_file_handle.write(json_object)

   def _get_quantization_params(
-      self,
+      self, calibration_result: Optional[_CalibrationResult] = None
   ) -> _TensorTransformationParams:
     """Gets the quantization parameters.

+    Args:
+      calibration_result: Calibration result to be used for quantization (if
+        needed, check with self.need_calibration).
+
     Returns:
       A dictionary containing the quantization parameters.
     """
     params_generator_instance = params_generator.ParamsGenerator(
         self.float_model
     )
     return params_generator_instance.generate_quantization_parameters(
-        self._recipe_manager
+        self._recipe_manager, calibration_result
     )

   def _get_quantized_model(
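Taken together, quantizer.py now exposes an explicit calibrate-then-quantize flow instead of a single quantize() call. An end-to-end sketch using only APIs visible in this commit; the model path, recipe path, and calibration_dataset are placeholders:

  qt = quantizer.Quantizer('/tmp/model.tflite')
  qt.load_quantization_recipe('/tmp/a8w8_recipe.json')

  calibration_result = None
  if qt.need_calibration:  # True for static-range recipes (e.g., a8w8).
    calibration_result = qt.calibrate(calibration_dataset)
    # A previous result can seed a resumed calibration:
    # qt.calibrate(more_data, previous_calibration_result=calibration_result)

  result = qt.quantize(calibration_result)  # raises if the recipe is empty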