diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
index e164469..f412402 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
@@ -215,7 +215,6 @@ def init_qsvs(
   """
   # Set min/max to 0/6 to help stablize the calibration process
-  # http://google3/third_party/py/mojax/operation_manager/algorithms/ptq_static_range.py?l=297
   init_min_val, init_max_val = 0.0, 6.0
   op_qsvs = {}
@@ -253,7 +252,7 @@ def min_max_calibrate(
     inputs_to_ignore: Optional[list[int]] = None,
     outputs_to_ignore: Optional[list[int]] = None,
 ) -> dict[str, qtyping.QSV]:
-  """Collect quantization statistics variable (QSV, e.g., scale/zero_point) for the op.
+  """Collect quantization statistics variable (QSV, e.g., min/max) for the op.
 
   Args:
     tfl_op: the tfl operation.
diff --git a/ai_edge_quantizer/calibrator.py b/ai_edge_quantizer/calibrator.py
index 6521497..3eab926 100644
--- a/ai_edge_quantizer/calibrator.py
+++ b/ai_edge_quantizer/calibrator.py
@@ -1,7 +1,8 @@
 """Quantization Calibration."""
 
 from collections.abc import Iterable
-from typing import Any, Optional
+import copy
+from typing import Any
 
 from absl import logging
 
@@ -19,11 +20,11 @@ class Calibrator:
 
   def __init__(
       self,
-      float_tflite_path: str,
+      float_tflite: str | bytearray,
   ):
-    self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite_path)
+    self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
     self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
-        float_tflite_path
+        float_tflite
     )
     # Tensor name to tensor content.
     self._tensor_content_map: dict[str, Any] = {}
@@ -35,7 +36,7 @@ def calibrate(
       self,
       calibration_dataset: Iterable[_SignatureInput],
       model_recipe_manager: recipe_manager.RecipeManager,
-      signature_key: Optional[str] = None,
+      signature_key: str | None = None,
   ) -> None:
     """Calibrates the model using the given dataset for a model signature.
 
@@ -134,7 +135,7 @@ def load_model_qsvs(self, model_qsvs: dict[str, qtyping.QSV]) -> None:
     Args:
       model_qsvs: A dictionary of tensor name to QSV.
     """
-    self._model_qsvs = model_qsvs
+    self._model_qsvs = copy.deepcopy(model_qsvs)
 
   def _update_qsvs(self, op_qsvs: dict[str, qtyping.QSV]):
     """Update the model qsvs with the new values.
diff --git a/ai_edge_quantizer/examples/mnist_toy_model.py b/ai_edge_quantizer/examples/mnist_toy_model.py
index b0e2b7b..3f2f799 100644
--- a/ai_edge_quantizer/examples/mnist_toy_model.py
+++ b/ai_edge_quantizer/examples/mnist_toy_model.py
@@ -5,6 +5,7 @@
 """
 
 import os
+import random
 
 from absl import app
 from absl import flags
@@ -37,6 +38,26 @@
     '/tmp/',
     'Path to save the quantized model and recipe.',
 )
+_QUANTIZATION_MODE = flags.DEFINE_enum(
+    'quantization_mode',
+    'weight_only',
+    ['weight_only', 'drq', 'a8w8', 'a16w8'],
+    'How to quantize the model (e.g., weight_only, drq, a8w8, a16w8).',
+)
+
+
+def _get_calibration_data(
+    num_samples: int = 256,
+) -> list[dict[str, np.ndarray]]:
+  (x_train, _), _ = tf.keras.datasets.mnist.load_data()
+  x_train = x_train / 255.0  # Normalize pixel values to 0-1.
+  x_train = x_train.astype(np.float32)
+  x_train = x_train.reshape([-1, 28, 28, 1])
+  calibration_data = []
+  for _ in range(num_samples):
+    sample = random.choice(x_train)
+    calibration_data.append({'conv2d_input': sample.reshape([-1, 28, 28, 1])})
+  return calibration_data
 
 
 def read_img(img_path: str):
@@ -63,31 +84,45 @@ def read_img(img_path: str):
 
 def quantize(
     float_model_path: str,
-    execution_mode: _OpExecutionMode = _OpExecutionMode.WEIGHT_ONLY,
+    quantization_mode: str,
 ) -> quantizer.QuantizationResult:
   """Quantize the float model.
 
   Args:
     float_model_path: Path to the float model.
-    execution_mode: Execution mode for the quantized model.
+    quantization_mode: How to quantize the model (e.g., weight_only, drq).
 
   Returns:
     QuantResult: quantization result
   """
+  if quantization_mode == 'weight_only':
+    recipe_path = test_utils.get_path_to_datafile(
+        '../tests/recipes/conv_fc_mnist_weight_only_recipe.json'
+    )
+  elif quantization_mode == 'drq':
+    recipe_path = test_utils.get_path_to_datafile(
+        '../tests/recipes/conv_fc_mnist_drq_recipe.json'
+    )
+  elif quantization_mode == 'a8w8':
+    recipe_path = test_utils.get_path_to_datafile(
+        '../tests/recipes/conv_fc_mnist_a8w8_recipe.json'
+    )
+  elif quantization_mode == 'a16w8':
+    recipe_path = test_utils.get_path_to_datafile(
+        '../tests/recipes/conv_fc_mnist_a16w8_recipe.json'
+    )
+  else:
+    raise ValueError(
+        'Invalid quantization mode. Only weight_only, drq, a8w8, a16w8 are'
+        ' supported.'
+    )
+
   qt = quantizer.Quantizer(float_model_path)
-  qt.update_quantization_recipe(
-      regex='.*',
-      operation_name=_OpName.FULLY_CONNECTED,
-      op_config=_OpQuantConfig(
-          weight_tensor_config=_TensorQuantConfig(
-              num_bits=8,
-              symmetric=False,
-              channel_wise=True,
-          ),
-          execution_mode=execution_mode,
-      ),
-  )
-  return qt.quantize()
+  qt.load_quantization_recipe(recipe_path)
+  calibration_result = None
+  if qt.need_calibration:
+    calibration_result = qt.calibrate(_get_calibration_data())
+  return qt.quantize(calibration_result)
 
 
 def inference(quantized_tflite: bytes, image_path: str) -> np.ndarray:
@@ -116,7 +151,7 @@ def main(_) -> None:
     )
   if not os.path.exists(_IMG_PATH.value):
     raise ValueError('Image file does not exist. Please check the image path.')
-  quant_result = quantize(_FLOAT_MODEL_PATH.value, _OpExecutionMode.WEIGHT_ONLY)
+  quant_result = quantize(_FLOAT_MODEL_PATH.value, _QUANTIZATION_MODE.value)
   category_probabilities = inference(
       quant_result.quantized_model, _IMG_PATH.value
   )
diff --git a/ai_edge_quantizer/model_modifier_test.py b/ai_edge_quantizer/model_modifier_test.py
index e248fd1..25f9b47 100644
--- a/ai_edge_quantizer/model_modifier_test.py
+++ b/ai_edge_quantizer/model_modifier_test.py
@@ -53,7 +53,6 @@ def test_modify_model(self):
                 },
                 'execution_mode': qtyping.OpExecutionMode.WEIGHT_ONLY,
             },
-            'override_algorithm': True,
         },
     ]
     recipe_manager_instance.load_quantization_recipe(global_recipe)
diff --git a/ai_edge_quantizer/params_generator.py b/ai_edge_quantizer/params_generator.py
index 939be0a..369480d 100644
--- a/ai_edge_quantizer/params_generator.py
+++ b/ai_edge_quantizer/params_generator.py
@@ -215,13 +215,13 @@ def _check_buffer_sharing(self) -> None:
     """
     for tensors in self.buffer_to_tensors.values():
       first_tensor = tensors[0]
-      first_tensor_params = self.model_quant_results[
-          tfl_flatbuffer_utils.get_tensor_name(first_tensor)
-      ]
+      first_tensor_params = self.model_quant_results.get(
+          tfl_flatbuffer_utils.get_tensor_name(first_tensor), None
+      )
       for tensor in tensors[1:]:
-        tensor_params = self.model_quant_results[
-            tfl_flatbuffer_utils.get_tensor_name(tensor)
-        ]
+        tensor_params = self.model_quant_results.get(
+            tfl_flatbuffer_utils.get_tensor_name(tensor), None
+        )
         error_msg = (
             f'The tensors {first_tensor.name} and {tensor.name} do not have the'
             ' same quantization parameters even though they share the same'
diff --git a/ai_edge_quantizer/params_generator_test.py b/ai_edge_quantizer/params_generator_test.py
index 8d1f905..40fcea4 100644
--- a/ai_edge_quantizer/params_generator_test.py
+++ b/ai_edge_quantizer/params_generator_test.py
@@ -124,7 +124,6 @@ def test_generate_config_global(self):
                 },
                 'execution_mode': _OpExecutionMode.WEIGHT_ONLY,
             },
-            'override_algorithm': True,
         },
     ]
     self._recipe_manager.load_quantization_recipe(global_recipe)
@@ -262,7 +261,6 @@ def test_generate_config_selective(self):
                 },
                 'execution_mode': _OpExecutionMode.DRQ,
            },
-            'override_algorithm': True,
         },
         {
             'regex': '.*/dense_1/.*',
@@ -277,7 +275,6 @@ def test_generate_config_selective(self):
                 },
                 'execution_mode': _OpExecutionMode.WEIGHT_ONLY,
             },
-            'override_algorithm': True,
         },
     ]
     self._recipe_manager.load_quantization_recipe(selective_quantization_recipe)
@@ -328,7 +325,6 @@ def test_generate_config_edge_cases(self):
                 },
                 'execution_mode': _OpExecutionMode.DRQ,
             },
-            'override_algorithm': True,
         },
         # Scope that does not exist in the model.
        {
@@ -343,7 +339,6 @@
                 },
                 'execution_mode': _OpExecutionMode.WEIGHT_ONLY,
             },
-            'override_algorithm': True,
         },
     ]
     self._recipe_manager.load_quantization_recipe(selective_quantization_recipe)
diff --git a/ai_edge_quantizer/quantizer.py b/ai_edge_quantizer/quantizer.py
index d9a2368..bfdeefa 100644
--- a/ai_edge_quantizer/quantizer.py
+++ b/ai_edge_quantizer/quantizer.py
@@ -6,6 +6,7 @@ import os
 from typing import Any, Optional, Union
 
 from ai_edge_quantizer import algorithm_manager
+from ai_edge_quantizer import calibrator
 from ai_edge_quantizer import model_modifier
 from ai_edge_quantizer import model_validator
 from ai_edge_quantizer import params_generator
@@ -24,6 +25,7 @@ _TensorQuantizationConfig = qtyping.TensorQuantizationConfig
 _TensorTransformationParams = dict[str, qtyping.TensorTransformationParams]
 _SignatureInput = dict[str, Any]  # input_argument_name -> tensor_value.
+_CalibrationResult = dict[str, qtyping.QSV]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -126,7 +128,6 @@ def update_quantization_recipe(
       operation_name: _TFLOpName,
       op_config: Optional[_OpQuantizationConfig] = None,
       algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
-      override_algorithm: bool = True,
   ):
     """Adds a quantization configuration to the recipe.
 
@@ -146,25 +147,62 @@
         default configuration. None or empty dict means the default
         configuration will be used.
       algorithm_key: Algorithm key to be applied.
-      override_algorithm: Flag to check if this rule overrides the previously
-        matched rule with different algorithm key.
     """
     self._recipe_manager.add_quantization_config(
-        regex, operation_name, op_config, algorithm_key, override_algorithm
+        regex, operation_name, op_config, algorithm_key
     )
 
-  def quantize(self) -> QuantizationResult:
+  @property
+  def need_calibration(self) -> bool:
+    """Checks if the current recipe needs calibration."""
+    return self._recipe_manager.need_calibration()
+
+  def calibrate(
+      self,
+      calibration_data: Iterable[_SignatureInput],
+      signature_key: Optional[str] = None,
+      previous_calibration_result: Optional[_CalibrationResult] = None,
+  ) -> _CalibrationResult:
+    """Calibrates the float model (required by static range quantization).
+
+    Args:
+      calibration_data: Calibration data for a model signature.
+      signature_key: The signature key to be used for invoking the models. If
+        the model doesn't have a signature key, this can be set to None.
+      previous_calibration_result: Previous calibration result to be loaded. The
+        calibration process will be resumed from the previous result.
+
+    Returns:
+      Calibration result ({tensor_name: tensor QSVs (e.g., min/max)}).
+    """
+    if not self.need_calibration:
+      return {}
+
+    calib = calibrator.Calibrator(self.float_model)
+    if previous_calibration_result is not None:
+      calib.load_model_qsvs(previous_calibration_result)
+    calib.calibrate(calibration_data, self._recipe_manager, signature_key)
+    return calib.get_model_qsvs()
+
+  def quantize(
+      self, calibration_result: Optional[_CalibrationResult] = None
+  ) -> QuantizationResult:
     """Quantizes the float model.
 
+    Args:
+      calibration_result: Calibration result to be used for quantization (if
+        needed, check with self.need_calibration).
+
     Returns:
       Quantization result.
 
     Raises:
-      RuntimeError: If no quantization recipe is loaded.
+      RuntimeError: If the quantization recipe is empty.
""" + if not self.get_quantization_recipe(): raise RuntimeError('Can not quantize without a quantization recipe.') - quant_params = self._get_quantization_params() + quant_params = self._get_quantization_params(calibration_result) quantized_model = self._get_quantized_model(quant_params) self._result = QuantizationResult( self.get_quantization_recipe(), quantized_model @@ -216,7 +254,6 @@ def compare( self.float_model, self._result.quantized_model, signature_test_data, - quantize_target_input=False, # will be removed later. compare_fn=validation_utils.get_validation_func(error_metrics), signature_key=signature_key, ) @@ -245,10 +282,14 @@ def save_comparison_result( output_file_handle.write(json_object) def _get_quantization_params( - self, + self, calibration_result: Optional[_CalibrationResult] = None ) -> _TensorTransformationParams: """Gets the quantization parameters. + Args: + calibration_result: Calibration result to be used for quantization (if + needed, check with self.need_calibration). + Returns: A dictionary containing the quantization parameters. """ @@ -256,7 +297,7 @@ def _get_quantization_params( self.float_model ) return params_generator_instance.generate_quantization_parameters( - self._recipe_manager + self._recipe_manager, calibration_result ) def _get_quantized_model( diff --git a/ai_edge_quantizer/quantizer_test.py b/ai_edge_quantizer/quantizer_test.py index d01b567..f348d79 100644 --- a/ai_edge_quantizer/quantizer_test.py +++ b/ai_edge_quantizer/quantizer_test.py @@ -2,6 +2,10 @@ import json import os + +from absl.testing import parameterized +import numpy as np + from tensorflow.python.platform import googletest from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer @@ -14,9 +18,19 @@ _AlgorithmName = quantizer.AlgorithmName TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('') +_RNG = np.random.default_rng(66) + + +def _get_calibration_data(num_samples: int = 256): + calibration_data = [] + for _ in range(num_samples): + calibration_data.append( + {'conv2d_input': _RNG.uniform(size=(1, 28, 28, 1)).astype(np.float32)} + ) + return calibration_data -class QuantizerTest(googletest.TestCase): +class QuantizerTest(parameterized.TestCase): def setUp(self): super().setUp() @@ -45,7 +59,6 @@ def test_update_quantization_recipe_succeeds(self): operation_name=qtyping.TFLOperationName.FULLY_CONNECTED, algorithm_key=_AlgorithmName.MIN_MAX_UNIFORM_QUANT, op_config=new_op_config, - override_algorithm=True, ) updated_recipe = self._quantizer.get_quantization_recipe() self.assertLen(updated_recipe, 2) @@ -72,13 +85,93 @@ def test_load_quantization_recipe_succeeds(self): qt.load_quantization_recipe(new_recipe_path) self.assertEqual(qt.get_quantization_recipe(), new_recipe) - def test_quantize_succeeds(self): + @parameterized.parameters( + 'tests/recipes/conv_fc_mnist_a8w8_recipe.json', + 'tests/recipes/conv_fc_mnist_a16w8_recipe.json', + ) + def test_calibrate_required_recipe_succeeds(self, recipe_path): + recipe_path = os.path.join(TEST_DATA_PREFIX_PATH, recipe_path) + self._quantizer.load_quantization_recipe(recipe_path) + self.assertTrue(self._quantizer.need_calibration) + # Calibrate with empty state. 
+    calib_data = _get_calibration_data()
+    calibration_result = self._quantizer.calibrate(calib_data)
+    self.assertLen(calibration_result, 13)
+
+  @parameterized.parameters(
+      'tests/recipes/conv_fc_mnist_a8w8_recipe.json',
+      'tests/recipes/conv_fc_mnist_a16w8_recipe.json',
+  )
+  def test_reloaded_calibration_succeeds(self, recipe_path):
+    recipe_path = os.path.join(TEST_DATA_PREFIX_PATH, recipe_path)
+    self._quantizer.load_quantization_recipe(recipe_path)
+    calib_data = _get_calibration_data()
+    calibration_result = self._quantizer.calibrate(calib_data)
+    # Load and calibrate again.
+    updated_calibration_result = self._quantizer.calibrate(
+        calib_data, previous_calibration_result=calibration_result
+    )
+    self.assertLen(updated_calibration_result, 13)
+    self.assertNotEqual(
+        calibration_result['StatefulPartitionedCall:0'],
+        updated_calibration_result['StatefulPartitionedCall:0'],
+    )
+
+  @parameterized.parameters(
+      'tests/recipes/conv_fc_mnist_drq_recipe.json',
+      'tests/recipes/conv_fc_mnist_weight_only_recipe.json',
+  )
+  def test_calibrate_nonrequired_recipe_succeeds(self, recipe_path):
+    recipe_path = os.path.join(TEST_DATA_PREFIX_PATH, recipe_path)
+    self._quantizer.load_quantization_recipe(recipe_path)
+    self.assertFalse(self._quantizer.need_calibration)
+    # Empty calibration result if no calibration is required.
+    calibration_result = self._quantizer.calibrate(_get_calibration_data())
+    self.assertEmpty(calibration_result)
+
+  def test_quantize_no_calibration_succeeds(self):
     self._quantizer.load_quantization_recipe(self._test_recipe_path)
     self.assertIsNone(self._quantizer._result.quantized_model)
     quant_result = self._quantizer.quantize()
     self.assertEqual(quant_result.recipe, self._test_recipe)
     self.assertIsNotNone(quant_result.quantized_model)
 
+  @parameterized.parameters(
+      'tests/recipes/conv_fc_mnist_a8w8_recipe.json',
+      'tests/recipes/conv_fc_mnist_a16w8_recipe.json',
+  )
+  def test_quantize_calibration_needed_succeeds(self, recipe_path):
+    recipe_path = os.path.join(TEST_DATA_PREFIX_PATH, recipe_path)
+    with open(recipe_path) as json_file:
+      recipe = json.load(json_file)
+
+    self._quantizer.load_quantization_recipe(recipe_path)
+    self.assertTrue(self._quantizer.need_calibration)
+    calibration_result = self._quantizer.calibrate(_get_calibration_data())
+
+    self.assertIsNone(self._quantizer._result.quantized_model)
+    quant_result = self._quantizer.quantize(calibration_result)
+    self.assertEqual(quant_result.recipe, recipe)
+    self.assertIsNotNone(quant_result.quantized_model)
+
+  @parameterized.parameters(
+      'tests/recipes/conv_fc_mnist_a8w8_recipe.json',
+      'tests/recipes/conv_fc_mnist_a16w8_recipe.json',
+  )
+  def test_quantize_calibration_needed_raise_error(self, recipe_path):
+    recipe_path = os.path.join(TEST_DATA_PREFIX_PATH, recipe_path)
+
+    self._quantizer.load_quantization_recipe(recipe_path)
+    self.assertTrue(self._quantizer.need_calibration)
+    error_message = (
+        'Model quantization statistics values (QSVs) are required for the input'
+        ' recipe.'
+    )
+    with self.assertRaisesWithPredicateMatch(
+        RuntimeError, lambda err: error_message in str(err)
+    ):
+      self._quantizer.quantize()
+
   def test_quantize_no_recipe_raise_error(self):
     qt = quantizer.Quantizer(self._test_model_path, None)
     error_message = 'Can not quantize without a quantization recipe.'
diff --git a/ai_edge_quantizer/recipe_manager.py b/ai_edge_quantizer/recipe_manager.py
index 0e9e562..9dc9394 100644
--- a/ai_edge_quantizer/recipe_manager.py
+++ b/ai_edge_quantizer/recipe_manager.py
@@ -38,16 +38,6 @@ class OpQuantizationRecipe:
       default_factory=_OpQuantizationConfig
   )
 
-  # Flag to check if this rule overrides the previous matched rule with
-  # different algorithm key. Used when the algorithm keys of previous matched
-  # config and the current config are different. When set to true, the
-  # previously matched config is ignored; otherwise, the current matched config
-  # is ignored.
-  # When the algorithm keys of both configs are the same, then this flag does
-  # not have any effect; the op_config of previously matched config is updated
-  # using the op_config of this one.
-  override_algorithm: bool = True
-
 
 class RecipeManager:
   """Sets the quantization recipe for target model.
@@ -74,7 +64,6 @@ def add_quantization_config(
       operation_name: _TFLOpName,
       op_config: Optional[_OpQuantizationConfig] = None,
       algorithm_key: str = algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT,
-      override_algorithm: bool = True,
   ) -> None:
     """Adds a quantization configuration.
 
@@ -94,14 +83,12 @@
         default configuration. None or empty dict means the default
         configuration will be used.
       algorithm_key: Algorithm key to be applied.
-      override_algorithm: Flag to check if this rule overrides the previously
-        matched rule with different algorithm key.
     """
     if op_config is None:
       op_config = _OpQuantizationConfig()
 
     config = OpQuantizationRecipe(
-        regex, operation_name, algorithm_key, op_config, override_algorithm
+        regex, operation_name, algorithm_key, op_config
     )
 
     # Special care if trying to set all ops to some config.
     if config.operation == _TFLOpName.ALL_SUPPORTED:
@@ -141,7 +128,6 @@
       configs.append(config)
     self._scope_configs[regex] = configs
 
-  # TODO: b/348469513 - Remove the override_algorithm flag.
   def get_quantization_configs(
       self,
       target_op_name: _TFLOpName,
@@ -151,13 +137,8 @@
 
     We respect the latest valid config and fall back to no quantization.
     Specifically, we search the quantization configuration in the order of the
-    scope configs. If there are two or more matching rules, if the same
-    quantization algorithms are assigned for both rules, then we will overwrite
-    the quantization config with the later one (if it is valid). If the assigned
-    algorithms are different,override_algorithm flag is used to see which
-    algorithm will be used. If the flag is True, the latter is used. If the flag
-    is False, the latter is ignored. We will fall to no quantization if no
-    matching rule is found or all matched configs are invalid.
+    scope configs. If there are two or more matching settings, the latest one
+    will be used.
 
 
     Args:
@@ -168,10 +149,9 @@
 
     Returns:
       A tuple of quantization algorithm, and quantization configuration.
     """
-    result_key, result_config, selected_recipe = (
+    result_key, result_config = (
        AlgorithmName.NO_QUANTIZE,
        _OpQuantizationConfig(),
-        None,
    )
    for scope_regex, recipes in self._scope_configs.items():
      if re.search(scope_regex, scope_name):
        for recipe in recipes:
          if (
              recipe.operation != _TFLOpName.ALL_SUPPORTED
              and recipe.operation != target_op_name
          ):
            continue
-          if (
-              result_key != recipe.algorithm_key
-              and not recipe.override_algorithm
-          ):
-            continue
          selected_recipe = recipe
          # The selected recipe must contain a supported config.
          try:
@@ -197,19 +172,6 @@
          result_config = selected_recipe.op_config
          result_key = selected_recipe.algorithm_key
 
-    if (
-        selected_recipe is not None
-        and selected_recipe.operation == _TFLOpName.ALL_SUPPORTED
-        and result_config != selected_recipe.op_config
-    ):
-      logging.warning(
-          'Ignored operation %s with config %s under scope_regex %s. Since the'
-          ' specified quantization config is not supported at the moment.'
-          ' (Triggered by quantizing ALL_SUPPORTED ops under a scope.)',
-          target_op_name,
-          selected_recipe.op_config,
-          selected_recipe.regex,
-      )
    return result_key, result_config
 
  def get_quantization_recipe(self) -> ModelQuantizationRecipe:
@@ -226,7 +188,6 @@
      config['operation'] = quant_config.operation
      config['algorithm_key'] = quant_config.algorithm_key
      config['op_config'] = quant_config.op_config.to_dict()
-      config['override_algorithm'] = quant_config.override_algorithm
      ret.append(config)
    return ret
 
@@ -246,7 +207,6 @@
          config['operation'],
          _OpQuantizationConfig.from_dict(config['op_config']),
          config['algorithm_key'],
-          config['override_algorithm'],
      )
 
  def need_calibration(self) -> bool:
diff --git a/ai_edge_quantizer/recipe_manager_test.py b/ai_edge_quantizer/recipe_manager_test.py
index 6242e99..125ac78 100644
--- a/ai_edge_quantizer/recipe_manager_test.py
+++ b/ai_edge_quantizer/recipe_manager_test.py
@@ -380,7 +380,6 @@ def test_get_full_quantization_config(self):
                },
                'execution_mode': 'SRQ',
            },
-            'override_algorithm': True,
        },
        {
            'regex': '.*',
@@ -395,7 +394,6 @@ def test_get_full_quantization_config(self):
                },
                'execution_mode': 'WEIGHT_ONLY',
            },
-            'override_algorithm': True,
        },
        {
            'regex': '.*/Dense/.*',
@@ -410,7 +408,6 @@ def test_get_full_quantization_config(self):
                },
                'execution_mode': 'WEIGHT_ONLY',
            },
-            'override_algorithm': True,
        },
        {
            'regex': '.*/Dense_1/.*',
@@ -425,7 +422,6 @@ def test_get_full_quantization_config(self):
                },
                'execution_mode': 'WEIGHT_ONLY',
            },
-            'override_algorithm': True,
        },
        {
            'regex': '.*/Dense_1/.*',
@@ -440,7 +436,6 @@ def test_get_full_quantization_config(self):
                },
                'execution_mode': 'WEIGHT_ONLY',
            },
-            'override_algorithm': True,
        },
    ]
    self.assertEqual(
@@ -463,7 +458,6 @@ def test_load_from_full_quantization_config(self):
                },
                'execution_mode': 'WEIGHT_ONLY',
            },
-            'override_algorithm': True,
        },
        {
            'regex': '.*/Dense/.*',
@@ -478,7 +472,6 @@ def test_load_from_full_quantization_config(self):
                },
                'execution_mode': 'DRQ',
            },
-            'override_algorithm': True,
        },
    ]
    self._recipe_manager.load_quantization_recipe(full_quantization_config)
@@ -534,7 +527,6 @@ def test_load_from_full_quantization_config_full_integer(self):
                },
                'execution_mode': 'SRQ',
            },
-            'override_algorithm': True,
        },
        {
            'regex': '.*',
@@ -549,7 +541,6 @@ def test_load_from_full_quantization_config_full_integer(self):
                },
                'execution_mode': 'WEIGHT_ONLY',
            },
-            'override_algorithm': True,
        },
        {
            'regex': '.*/Dense/.*',
@@ -564,7 +555,6 @@ def test_load_from_full_quantization_config_full_integer(self):
                },
                'execution_mode': 'DRQ',
            },
-            'override_algorithm': True,
        },
    ]
    self._recipe_manager.load_quantization_recipe(full_quantization_config)
diff --git a/ai_edge_quantizer/tests/mnist_test.py b/ai_edge_quantizer/tests/mnist_test.py
index 148d115..646cee4 100644
--- a/ai_edge_quantizer/tests/mnist_test.py
+++ b/ai_edge_quantizer/tests/mnist_test.py
@@ -1,6 +1,8 @@
 """E2E tests for the quantizer using a toy MNIST model."""
 
 from absl.testing import parameterized
+import numpy as np
+
 from tensorflow.python.platform import googletest
 
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer import quantizer
@@ -11,6 +13,25 @@
 _TensorQuantConfig = qtyping.TensorQuantizationConfig
 _OpQuantConfig = qtyping.OpQuantizationConfig
 
+_RNG = np.random.default_rng(66)
+
+
+def _get_dummy_data(num_samples):
+  data = []
+  for _ in range(num_samples):
+    data.append(
+        {'conv2d_input': _RNG.uniform(size=(1, 28, 28, 1)).astype(np.float32)}
+    )
+  return data
+
+
+def _get_calibration_data(num_samples: int = 256):
+  return _get_dummy_data(num_samples)
+
+
+def _get_test_data(num_samples: int = 8):
+  return _get_dummy_data(num_samples)
+
 
 class MNISTTest(parameterized.TestCase):
 
@@ -19,6 +40,7 @@ def setUp(self):
     self.float_model_path = test_utils.get_path_to_datafile(
         'models/conv_fc_mnist.tflite'
     )
+    self._quantizer = quantizer.Quantizer(self.float_model_path)
 
   @parameterized.product(
       execution_mode=[_OpExecutionMode.WEIGHT_ONLY, _OpExecutionMode.DRQ],
@@ -33,8 +55,7 @@ def test_mnist_toy_model_int8_weight_only(
     # TODO: b/335254997 - fail when trying to use unsupported recipe.
     if execution_mode == _OpExecutionMode.DRQ and not symmetric_weight:
       return
-    qt = quantizer.Quantizer(self.float_model_path)
-    qt.update_quantization_recipe(
+    self._quantizer.update_quantization_recipe(
         regex='.*',
         operation_name=_OpName.FULLY_CONNECTED,
         op_config=_OpQuantConfig(
             weight_tensor_config=_TensorQuantConfig(
@@ -46,30 +67,57 @@
             execution_mode=execution_mode,
         ),
     )
-    _ = qt.quantize()
+    _ = self._quantizer.quantize()
     # Check model size.
-    self.assertLess(len(qt._result.quantized_model), 55000)
+    self.assertLess(len(self._quantizer._result.quantized_model), 55000)
 
-    comparion_result = qt.compare(error_metrics='mse')
-    # Check weight tensors.
-    tolerance = 1e-2 if channel_wise_weight else 1e-1
-    conv_weight_mse = comparion_result['sequential/conv2d/Conv2D']
-    self.assertLess(conv_weight_mse, tolerance)
-    fc1_weight_mse = comparion_result['arith.constant1']
-    self.assertLess(fc1_weight_mse, tolerance)
-    fc2_weight_mse = comparion_result['arith.constant']
-    self.assertLess(fc2_weight_mse, tolerance)
-    # check logits.
-    logits_mse = comparion_result['sequential/dense_1/MatMul']
-    self.assertLess(logits_mse, tolerance)
-    # check final output.
-    output_mse = comparion_result['StatefulPartitionedCall:0']
-    self.assertLess(output_mse, tolerance)
-    # TODO: b/345503484 - Check weight tensor type of the quantized model.
+    comparion_result = self._quantizer.compare(error_metrics='mse')
+    self._check_comparion_result(
+        comparion_result,
+        weight_tolerance=1e-2 if channel_wise_weight else 1e-1,
+        logits_tolerance=1e-2 if channel_wise_weight else 1e-1,
+        output_tolerance=1e-4 if channel_wise_weight else 1e-2,
+    )
+
+  @parameterized.product(
+      execution_mode=[_OpExecutionMode.WEIGHT_ONLY, _OpExecutionMode.DRQ],
+      symmetric_weight=[True, False],
+  )
+  def test_mnist_toy_model_int4_weight_only(
+      self, execution_mode, symmetric_weight
+  ):
+
+    # Asym DRQ is not supported.
+    # TODO: b/335254997 - Fail when trying to use unsupported recipe.
+    if execution_mode == _OpExecutionMode.DRQ and not symmetric_weight:
+      return
+    self._quantizer.update_quantization_recipe(
+        regex='.*',
+        operation_name=_OpName.FULLY_CONNECTED,
+        op_config=_OpQuantConfig(
+            weight_tensor_config=_TensorQuantConfig(
+                num_bits=4,
+                symmetric=symmetric_weight,
+                channel_wise=True,
+            ),
+            execution_mode=execution_mode,
+        ),
+    )
+    _ = self._quantizer.quantize()
+    # Check model size.
+    self.assertLess(len(self._quantizer._result.quantized_model), 30000)
+
+    comparion_result = self._quantizer.compare(error_metrics='mse')
+    # TODO: b/346787369 - Update the weight tolerance for int4.
+    self._check_comparion_result(
+        comparion_result,
+        weight_tolerance=1000,
+        logits_tolerance=2,
+        output_tolerance=1e-2,
+    )
 
   def test_mnist_toy_model_fp16_weight_only(self):
-    qt = quantizer.Quantizer(self.float_model_path)
-    qt.update_quantization_recipe(
+    self._quantizer.update_quantization_recipe(
         regex='.*',
         algorithm_key=quantizer.AlgorithmName.FLOAT_CASTING,
         operation_name=_OpName.FULLY_CONNECTED,
@@ -80,26 +128,63 @@ def test_mnist_toy_model_fp16_weight_only(self):
             execution_mode=_OpExecutionMode.WEIGHT_ONLY,
         ),
     )
-    _ = qt.quantize()
+    _ = self._quantizer.quantize()
+    # Check model size.
+    self.assertLess(len(self._quantizer._result.quantized_model), 105000)
+
+    comparion_result = self._quantizer.compare(error_metrics='mse')
+    self._check_comparion_result(
+        comparion_result,
+        weight_tolerance=1e-5,
+        logits_tolerance=1e-5,
+        output_tolerance=1e-5,
+    )
+
+  @parameterized.parameters(
+      'recipes/conv_fc_mnist_a8w8_recipe.json',
+      'recipes/conv_fc_mnist_a16w8_recipe.json',
+  )
+  def test_mnist_toy_model_full_integer(self, recipe_path):
+    recipe_path = test_utils.get_path_to_datafile(recipe_path)
+    self._quantizer.load_quantization_recipe(recipe_path)
+    self.assertTrue(self._quantizer.need_calibration)
+    calibration_result = self._quantizer.calibrate(_get_calibration_data())
+    quant_result = self._quantizer.quantize(calibration_result)
     # Check model size.
-    self.assertLess(len(qt._result.quantized_model), 105000)
+    self.assertLess(len(quant_result.quantized_model), 55000)
+
+    comparion_result = self._quantizer.compare(
+        error_metrics='mse', signature_test_data=_get_test_data()
+    )
+    self._check_comparion_result(
+        comparion_result,
+        weight_tolerance=1e-2,
+        logits_tolerance=1e-1,
+        output_tolerance=1e-4,
+    )
 
-    comparion_result = qt.compare(error_metrics='mse')
+  # TODO: b/345503484 - Check weight tensor type of the quantized model.
+  def _check_comparion_result(
+      self,
+      comparion_result,
+      weight_tolerance,
+      logits_tolerance,
+      output_tolerance,
+  ):
     # Check weight tensors.
-    tolerance = 1e-5
     conv_weight_mse = comparion_result['sequential/conv2d/Conv2D']
-    self.assertLess(conv_weight_mse, tolerance)
+    self.assertLess(conv_weight_mse, weight_tolerance)
     fc1_weight_mse = comparion_result['arith.constant1']
-    self.assertLess(fc1_weight_mse, tolerance)
+    self.assertLess(fc1_weight_mse, weight_tolerance)
     fc2_weight_mse = comparion_result['arith.constant']
-    self.assertLess(fc2_weight_mse, tolerance)
+    self.assertLess(fc2_weight_mse, weight_tolerance)
     # check logits.
     logits_mse = comparion_result['sequential/dense_1/MatMul']
-    self.assertLess(logits_mse, tolerance)
+    self.assertLess(logits_mse, logits_tolerance)
     # check final output.
     output_mse = comparion_result['StatefulPartitionedCall:0']
-    self.assertLess(output_mse, tolerance)
-    # TODO: b/345503484 - Check weight tensor type of the quantized model.
+    self.assertLess(output_mse, output_tolerance)
+
 
 if __name__ == '__main__':
   googletest.main()
diff --git a/ai_edge_quantizer/tests/recipes/conv_fc_mnist_a16w8_recipe.json b/ai_edge_quantizer/tests/recipes/conv_fc_mnist_a16w8_recipe.json
new file mode 100644
index 0000000..9105e80
--- /dev/null
+++ b/ai_edge_quantizer/tests/recipes/conv_fc_mnist_a16w8_recipe.json
@@ -0,0 +1 @@
+[{"regex": ".*", "operation": "*", "algorithm_key": "min_max_uniform_quantize", "op_config": {"activation_tensor_config": {"num_bits": 16, "symmetric": true, "channel_wise": false, "dtype": "INT"}, "weight_tensor_config": {"num_bits": 8, "symmetric": true, "channel_wise": true, "dtype": "INT"}, "execution_mode": "SRQ"}}]
\ No newline at end of file
diff --git a/ai_edge_quantizer/tests/recipes/conv_fc_mnist_a8w8_recipe.json b/ai_edge_quantizer/tests/recipes/conv_fc_mnist_a8w8_recipe.json
new file mode 100644
index 0000000..3034fcc
--- /dev/null
+++ b/ai_edge_quantizer/tests/recipes/conv_fc_mnist_a8w8_recipe.json
@@ -0,0 +1 @@
+[{"regex": ".*", "operation": "*", "algorithm_key": "min_max_uniform_quantize", "op_config": {"activation_tensor_config": {"num_bits": 8, "symmetric": false, "channel_wise": false, "dtype": "INT"}, "weight_tensor_config": {"num_bits": 8, "symmetric": true, "channel_wise": true, "dtype": "INT"}, "execution_mode": "SRQ"}}]
\ No newline at end of file
diff --git a/ai_edge_quantizer/tests/recipes/conv_fc_mnist_drq_recipe.json b/ai_edge_quantizer/tests/recipes/conv_fc_mnist_drq_recipe.json
index 53b0b30..29d9ab4 100644
--- a/ai_edge_quantizer/tests/recipes/conv_fc_mnist_drq_recipe.json
+++ b/ai_edge_quantizer/tests/recipes/conv_fc_mnist_drq_recipe.json
@@ -1 +1 @@
-[{"regex": ".*", "operation": "FULLY_CONNECTED", "algorithm_key": "min_max_uniform_quantize", "op_config": {"weight_tensor_config": {"num_bits": 8, "symmetric": true, "channel_wise": true, "dtype": "INT"}, "execution_mode": "DRQ"}, "override_algorithm": true}]
\ No newline at end of file
+[{"regex": ".*", "operation": "FULLY_CONNECTED", "algorithm_key": "min_max_uniform_quantize", "op_config": {"weight_tensor_config": {"num_bits": 8, "symmetric": true, "channel_wise": true, "dtype": "INT"}, "execution_mode": "DRQ"}}]
\ No newline at end of file
diff --git a/ai_edge_quantizer/tests/recipes/conv_fc_mnist_weight_only_recipe.json b/ai_edge_quantizer/tests/recipes/conv_fc_mnist_weight_only_recipe.json
index b7d1e16..8290b95 100644
--- a/ai_edge_quantizer/tests/recipes/conv_fc_mnist_weight_only_recipe.json
+++ b/ai_edge_quantizer/tests/recipes/conv_fc_mnist_weight_only_recipe.json
@@ -1 +1 @@
-[{"regex": ".*", "operation": "FULLY_CONNECTED", "algorithm_key": "min_max_uniform_quantize", "op_config": {"weight_tensor_config": {"num_bits": 8, "symmetric": false, "channel_wise": true, "dtype": "INT"}, "execution_mode": "WEIGHT_ONLY"}, "override_algorithm": true}]
\ No newline at end of file
+[{"regex": ".*", "operation": "FULLY_CONNECTED", "algorithm_key": "min_max_uniform_quantize", "op_config": {"weight_tensor_config": {"num_bits": 8, "symmetric": false, "channel_wise": true, "dtype": "INT"}, "execution_mode": "WEIGHT_ONLY"}}]
\ No newline at end of file
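
Reviewer note: a minimal sketch of the calibrate-then-quantize flow introduced by this change, using only the APIs visible in the patch (Quantizer, load_quantization_recipe, need_calibration, calibrate, quantize). The model path, recipe path, and the 'conv2d_input' signature input name are illustrative placeholders borrowed from the MNIST example above; they are not part of the patch.

import numpy as np

from ai_edge_quantizer import quantizer

# Placeholder paths; any float TFLite model and recipe JSON work here.
_FLOAT_MODEL_PATH = '/tmp/conv_fc_mnist.tflite'
_RECIPE_PATH = 'ai_edge_quantizer/tests/recipes/conv_fc_mnist_a8w8_recipe.json'

qt = quantizer.Quantizer(_FLOAT_MODEL_PATH)
qt.load_quantization_recipe(_RECIPE_PATH)

# Static range recipes (a8w8/a16w8) require calibration; weight-only and DRQ
# recipes do not, in which case calibrate() simply returns an empty result.
calibration_result = None
if qt.need_calibration:
  rng = np.random.default_rng(0)
  calibration_data = [
      {'conv2d_input': rng.uniform(size=(1, 28, 28, 1)).astype(np.float32)}
      for _ in range(256)
  ]
  calibration_result = qt.calibrate(calibration_data)

quant_result = qt.quantize(calibration_result)
with open('/tmp/quantized_model.tflite', 'wb') as f:
  f.write(quant_result.quantized_model)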