Skip to content

Commit

Permalink
Add full integer quantization for SELECT_V2 in Quantizer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698579890
  • Loading branch information
v-dziuba authored and copybara-github committed Nov 21, 2024
1 parent 230c1fc commit 18aeb2a
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 2 deletions.
2 changes: 2 additions & 0 deletions ai_edge_quantizer/algorithm_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class AlgorithmName(str, enum.Enum):
_TFLOpName.LOGISTIC, # Sigmoid
_TFLOpName.SLICE,
_TFLOpName.SUM,
_TFLOpName.SELECT_V2,
),
(
naive_min_max_quantize.materialize_input,
Expand Down Expand Up @@ -118,6 +119,7 @@ class AlgorithmName(str, enum.Enum):
naive_min_max_quantize.materialize_softmax_and_logistic,
naive_min_max_quantize.materialize_slice,
naive_min_max_quantize.materialize_sum,
naive_min_max_quantize.materialize_select_v2,
),
):
register_quantized_op(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,23 @@ def materialize_slice(
)


def materialize_select_v2(
op_info: qtyping.OpInfo,
graph_info: qtyping.GraphInfo,
tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
"""Materialize tensors in tfl.select_v2."""
return utils.materialize_standard_op(
op_info,
graph_info,
tensor_name_to_qsv,
constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
inputs_to_ignore=[
0,
], # Condition tensor does not need to be quantized.
)


def materialize_sum(
op_info: qtyping.OpInfo,
graph_info: qtyping.GraphInfo,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

_TFLOpName = qtyping.TFLOperationName
_ComputePrecision = qtyping.ComputePrecision
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_QuantTransformation = qtyping.QuantTransformation
_OpTestInfo = naive_min_max_test_utils.OpTestInfo

_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
"../../../tests/models"
)

_DEFAULT_WEIGHT_QUANT_SETTING = (
naive_min_max_test_utils.DEFAULT_WEIGHT_QUANT_SETTING
)


class SelectV2Test(naive_min_max_test_utils.NaiveMinMaxQuantizeTest):

def setUp(self):
super().setUp()
np.random.seed(666)
self._test_model_path = os.path.join(
_TEST_DATA_PREFIX_PATH, "single_select_v2.tflite"
)
self._op_test_info = _OpTestInfo(
test_model=tfl_flatbuffer_utils.read_model(self._test_model_path),
op_tensor_names={},
input_range=(np.array([[-10]]), np.array([[10]])),
output_range=(np.array([[-10]]), np.array([[10]])),
)
# The test model has one subgraph for now.
self._graph_info = qtyping.GraphInfo(
subgraph_tensors=self._op_test_info.test_model.subgraphs[0].tensors,
buffers=self._op_test_info.test_model.buffers,
)

@parameterized.parameters(
8,
16,
)
def test_materialize_select_v2_succeeds(self, num_bits):
activation_tensor_config = _TensorQuantConfig(
num_bits=num_bits,
symmetric=True,
granularity=qtyping.QuantGranularity.TENSORWISE,
)
op_quant_config = qtyping.OpQuantizationConfig(
activation_tensor_config=activation_tensor_config,
weight_tensor_config=_DEFAULT_WEIGHT_QUANT_SETTING,
compute_precision=_ComputePrecision.INTEGER, # SRQ.
)
# Read from Model Explorer.
subgraph0 = self._op_test_info.test_model.subgraphs[0]
subgraph_op_id = 0
op = subgraph0.operators[subgraph_op_id]
op_info = qtyping.OpInfo(
op=op,
op_name=qtyping.TFLOperationName.SELECT_V2,
subgraph_op_index=subgraph_op_id,
op_quant_config=op_quant_config,
)

# Test settings.
op_tensor_names = {}
op_tensor_names["input"] = "selectv2_condition_tensor:0"
op_tensor_names["input2"] = "selectv2_t_tensor:0"
op_tensor_names["input3"] = "selectv2_e_tensor:0"
op_tensor_names["output"] = "PartitionedCall:0"
self._op_test_info.op_tensor_names = op_tensor_names
self._test_no_weights_op(
op_info,
self._graph_info,
self._op_test_info,
naive_min_max_quantize.materialize_select_v2,
same_input_output_params=True,
inputs_to_ignore=[0],
)


if __name__ == "__main__":
googletest.main()
6 changes: 4 additions & 2 deletions ai_edge_quantizer/default_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@
"OUTPUT",
"SLICE",
"EMBEDDING_LOOKUP",
"SUM"
"SUM",
"SELECT_V2"
],
"static_wi8_ai8": [
"ADD",
Expand All @@ -193,7 +194,8 @@
"OUTPUT",
"SLICE",
"EMBEDDING_LOOKUP",
"SUM"
"SUM",
"SELECT_V2"
],
"static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
"static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
Expand Down
1 change: 1 addition & 0 deletions ai_edge_quantizer/qtyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class TFLOperationName(str, enum.Enum):
LOGISTIC = 'LOGISTIC'
SLICE = 'SLICE'
SUM = 'SUM'
SELECT_V2 = 'SELECT_V2'


class QuantizeMode(enum.Enum):
Expand Down
117 changes: 117 additions & 0 deletions ai_edge_quantizer/tests/end_to_end_tests/select_v2_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""E2E tests for the quantizer for model with slice."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

_OpExecutionMode = qtyping.OpExecutionMode
_OpName = qtyping.TFLOperationName
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_OpQuantConfig = qtyping.OpQuantizationConfig

_RNG = np.random.default_rng(66)


def _get_dummy_data(num_samples):
data = []
for _ in range(num_samples):
data.append({
'condition_tensor': _RNG.uniform(size=(1, 16)).astype(np.bool),
'e_tensor': _RNG.uniform(size=(1, 16)).astype(np.float32),
't_tensor': _RNG.uniform(size=(1, 16)).astype(np.float32),
})
return data


def _get_calibration_data(num_samples: int = 64):
calibration_samples = _get_dummy_data(num_samples)
calibration_data = {'selectv2': calibration_samples}
return calibration_data


def _get_test_data(num_samples: int = 16):
return _get_calibration_data(num_samples)


class SelectV2Test(parameterized.TestCase):

def _custom_setup(self, test_model_file):
super().setUp()
self.float_model_path = test_utils.get_path_to_datafile(
f'../models/{test_model_file}'
)
self._quantizer = quantizer.Quantizer(self.float_model_path)

@parameterized.parameters(
('../../recipes/default_a8w8_recipe.json', 9), # int8.
('../../recipes/default_a16w8_recipe.json', 7), # int16.
)
def test_select_v2_model_full_integer(self, recipe_path, tensor_type):
self._custom_setup('single_select_v2.tflite')
recipe_path = test_utils.get_path_to_datafile(recipe_path)
self._quantizer.load_quantization_recipe(recipe_path)
self.assertTrue(self._quantizer.need_calibration)
calibration_result = self._quantizer.calibrate(_get_calibration_data())
quantization_result = self._quantizer.quantize(calibration_result)

# Check input/output tensor type.
quantized_model = tfl_flatbuffer_utils.read_model(
quantization_result.quantized_model
)
self.assertLen(quantized_model.subgraphs, 1)
subgraph = quantized_model.subgraphs[0]
subgraph_tensors = subgraph.tensors
self.assertLen(subgraph.inputs, 3)
condition_tensor = subgraph_tensors[subgraph.inputs[0]]
e_tensor = subgraph_tensors[subgraph.inputs[1]]
t_tensor = subgraph_tensors[subgraph.inputs[2]]
output_tensor = subgraph_tensors[subgraph.outputs[0]]
# See schema_py_generated.py for type code.
self.assertEqual(condition_tensor.type, 6) # bool.
self.assertEqual(e_tensor.type, tensor_type)
self.assertEqual(t_tensor.type, tensor_type)
self.assertEqual(output_tensor.type, tensor_type)

comparison_result = self._quantizer.validate(
error_metrics='mse', test_data=_get_test_data()
)
self._check_comparison_result(
comparison_result,
output_tolerance=1e-4,
)

# TODO: b/345503484 - Check weight tensor type of the quantized model.
def _check_comparison_result(
self,
comparison_result,
output_tolerance,
):
# TODO: b/357959309 - Use comparison result directly for testing.
comparison_result = comparison_result.get_all_tensor_results()
# Check final output.
output_mse = comparison_result['PartitionedCall:0']
self.assertLess(output_mse, output_tolerance)


if __name__ == '__main__':
googletest.main()
Binary file not shown.
1 change: 1 addition & 0 deletions ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
_TFLOpName.LOGISTIC: schema_py_generated.BuiltinOperator.LOGISTIC,
_TFLOpName.SLICE: schema_py_generated.BuiltinOperator.SLICE,
_TFLOpName.SUM: schema_py_generated.BuiltinOperator.SUM,
_TFLOpName.SELECT_V2: schema_py_generated.BuiltinOperator.SELECT_V2,
})

TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
Expand Down

0 comments on commit 18aeb2a

Please sign in to comment.