Enable calibration and model validation with XNNPACK
PiperOrigin-RevId: 705626044
v-dziuba authored and copybara-github committed Dec 12, 2024
1 parent 3c0f35b commit 353abd9
Showing 4 changed files with 34 additions and 19 deletions.
ai_edge_quantizer/calibrator.py (3 changes: 2 additions & 1 deletion)
@@ -41,6 +41,7 @@ class Calibrator:
   def __init__(
       self,
       float_tflite: Union[str, bytes],
+      num_threads: int = 16,
   ):
     self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)

@@ -50,7 +51,7 @@ def __init__(
           " the model (e.g., if it is already quantized)."
       )
     self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
-        float_tflite
+        float_tflite, use_xnnpack=True, num_threads=num_threads
     )
     # Tensor name to tensor content.
     self._tensor_content_map: dict[str, Any] = {}
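
With this change, calibration runs through the XNNPACK delegate by default and the thread count becomes tunable. A minimal usage sketch, assuming a float (non-quantized) TFLite model at a hypothetical path "model.tflite":

from ai_edge_quantizer import calibrator

# "model.tflite" is a hypothetical path to a float TFLite model.
calib = calibrator.Calibrator("model.tflite", num_threads=8)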
ai_edge_quantizer/model_validator.py (25 changes: 16 additions & 9 deletions)
@@ -207,7 +208,8 @@ def _setup_validation_interpreter(
     model: bytes,
     signature_input: dict[str, Any],
     signature_key: Optional[str],
-    use_reference_kernel: bool,
+    use_xnnpack: bool,
+    num_threads: int,
 ) -> tuple[Any, int, dict[str, Any]]:
   """Setup the interpreter for validation given a signature key.

@@ -216,15 +217,15 @@
     signature_input: A dictionary of input tensor name and its value.
     signature_key: The signature key to be used for invoking the models. If the
       model only has one signature, this can be set to None.
-    use_reference_kernel: Whether to use the reference kernel for the
-      interpreter.
+    use_xnnpack: Whether to use xnnpack for the interpreter.
+    num_threads: The number of threads to use for the interpreter.

   Returns:
     A tuple of interpreter, subgraph_index and tensor_name_to_details.
   """

   interpreter = utils.create_tfl_interpreter(
-      tflite_model=model, use_reference_kernel=use_reference_kernel
+      tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
   )
   utils.invoke_interpreter_signature(
       interpreter, signature_input, signature_key
@@ -247,7 +248,8 @@ def compare_model(
     test_data: dict[str, Iterable[dict[str, Any]]],
     error_metric: str,
     compare_fn: Callable[[Any, Any], float],
-    use_reference_kernel: bool = False,
+    use_xnnpack: bool = True,
+    num_threads: int = 16,
 ) -> ComparisonResult:
   """Compares model tensors over a model signature using the compare_fn.

Expand All @@ -266,8 +268,8 @@ def compare_model(
compare_fn: a comparison function to be used for calculating the statistics,
this function must be taking in two ArrayLike strcuture and output a
single float value.
use_reference_kernel: Whether to use the reference kernel for the
interpreter.
use_xnnpack: Whether to use xnnpack for the interpreter.
num_threads: The number of threads to use for the interpreter.
Returns:
A ComparisonResult object.
Expand All @@ -282,12 +284,17 @@ def compare_model(
reference_model,
signature_input,
signature_key,
use_reference_kernel,
use_xnnpack,
num_threads,
)
)
targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
_setup_validation_interpreter(
target_model, signature_input, signature_key, use_reference_kernel
target_model,
signature_input,
signature_key,
use_xnnpack,
num_threads,
)
)
# Compare the cached tensor values.
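
The compare_fn contract stated in the docstring (two ArrayLike inputs, one float out) can be satisfied by a plain mean-squared-error function. A minimal numpy-based sketch, not the library's own metric implementation:

import numpy as np

def mse_compare(reference, target) -> float:
  # Mean squared error between two ArrayLike tensors of the same shape.
  ref = np.asarray(reference, dtype=np.float64)
  tgt = np.asarray(target, dtype=np.float64)
  return float(np.mean((ref - tgt) ** 2))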
ai_edge_quantizer/quantizer.py (13 changes: 9 additions & 4 deletions)
@@ -216,13 +216,15 @@ def calibrate(
       self,
       calibration_data: dict[str, Iterable[_SignatureInput]],
       previous_calibration_result: Optional[_CalibrationResult] = None,
+      num_threads: int = 16,
   ) -> _CalibrationResult:
     """Calibrates the float model (required by static range quantization).

     Args:
       calibration_data: Calibration data for a model signature.
       previous_calibration_result: Previous calibration result to be loaded. The
         calibration process will be resumed from the previous result.
+      num_threads: Number of threads to use for calibration.

     Returns:
       Calibration result ({tensor_name: tensor QSVs (e.g.,min/max)}).
@@ -233,7 +235,7 @@ def calibrate(
     if not self.need_calibration:
       return {}

-    calib = calibrator.Calibrator(self.float_model)
+    calib = calibrator.Calibrator(self.float_model, num_threads=num_threads)
     if previous_calibration_result is not None:
       calib.load_model_qsvs(previous_calibration_result)
     calib.calibrate(calibration_data, self._recipe_manager)
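
A hedged usage sketch of the updated calibrate API. The Quantizer construction, signature name, and input layout below are assumptions for illustration; only the calibrate parameters come from this hunk:

import numpy as np
from ai_edge_quantizer import quantizer

qt = quantizer.Quantizer("model.tflite")  # hypothetical float model path
# Assumed layout: signature name -> iterable of {input_name: tensor} dicts.
calibration_data = {
    "serving_default": [
        {"input": np.random.rand(1, 224, 224, 3).astype(np.float32)}
    ]
}
result = qt.calibrate(calibration_data, num_threads=8)
# Resuming from an earlier run is also supported:
result = qt.calibrate(
    calibration_data, previous_calibration_result=result, num_threads=8
)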
@@ -297,7 +299,8 @@ def validate(
       self,
       test_data: Optional[dict[str, Iterable[_SignatureInput]]] = None,
       error_metrics: str = 'mse',
-      use_reference_kernel: bool = False,
+      use_xnnpack: bool = True,
+      num_threads: int = 16,
   ) -> model_validator.ComparisonResult:
     """Numerical validation of the quantized model for a model signature.

@@ -314,7 +317,8 @@ def validate(
         data that will be used for validation. If set to None, random normal
         distributed data will be used for all signatures in the model.
       error_metrics: Error metrics to be used for comparison.
-      use_reference_kernel: Whether to use the reference kernel for validation.
+      use_xnnpack: Whether to use the xnnpack library for validation.
+      num_threads: Number of threads to use for validation.

     Returns:
       The comparison result.
@@ -330,7 +334,8 @@ def validate(
         test_data,
         error_metrics,
         validation_utils.get_validation_func(error_metrics),
-        use_reference_kernel=use_reference_kernel,
+        use_xnnpack=use_xnnpack,
+        num_threads=num_threads,
     )

   def _get_quantization_params(
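
A matching sketch for validation, continuing from the calibrate example above (qt is the same assumed Quantizer, already quantized). Per the docstring, passing test_data=None makes the validator generate random normally distributed inputs for every signature:

comparison = qt.validate(
    test_data=None,       # random normal data for all signatures
    error_metrics="mse",
    use_xnnpack=True,     # XNNPACK delegate, now the default
    num_threads=8,
)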
ai_edge_quantizer/utils/tfl_interpreter_utils.py (12 changes: 7 additions & 5 deletions)
@@ -30,15 +30,16 @@
 def create_tfl_interpreter(
     tflite_model: Union[str, bytes],
     allocate_tensors: bool = True,
-    use_reference_kernel: bool = False,
+    use_xnnpack: bool = True,
+    num_threads: int = 16,
 ) -> tfl.Interpreter:
   """Creates a TFLite interpreter from a model file.

   Args:
     tflite_model: Model file path or bytes.
     allocate_tensors: Whether to allocate tensors.
-    use_reference_kernel: Whether to use the reference kernel for the
-      interpreter.
+    use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
+    num_threads: The number of threads to use for the interpreter.

   Returns:
     A TFLite interpreter.
@@ -47,12 +48,13 @@ def create_tfl_interpreter(
     with gfile.GFile(tflite_model, "rb") as f:
       tflite_model = f.read()

-  if use_reference_kernel:
-    op_resolver = tfl.OpResolverType.BUILTIN_REF
+  if use_xnnpack:
+    op_resolver = tfl.OpResolverType.BUILTIN
   else:
     op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
   tflite_interpreter = tfl.Interpreter(
       model_content=bytes(tflite_model),
+      num_threads=num_threads,
       experimental_op_resolver_type=op_resolver,
       experimental_preserve_all_tensors=True,
   )
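
The resolver switch above is what toggles XNNPACK: BUILTIN keeps TFLite's default delegates (including XNNPACK) enabled, while BUILTIN_WITHOUT_DEFAULT_DELEGATES falls back to plain builtin kernels. A small sketch of both paths, using a hypothetical model path:

from ai_edge_quantizer.utils import tfl_interpreter_utils

# XNNPACK path: default delegates stay enabled.
interp = tfl_interpreter_utils.create_tfl_interpreter(
    "model.tflite", use_xnnpack=True, num_threads=8
)
# Plain-CPU path: default delegates (including XNNPACK) are disabled.
interp_no_xnn = tfl_interpreter_utils.create_tfl_interpreter(
    "model.tflite", use_xnnpack=False
)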
