From d9ff43840c71f7e760386dccfcf720c4f545b810 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Mon, 4 Apr 2022 10:36:02 -0700 Subject: [PATCH] Dump input values to a file. --- tflitehub/mobilenet_v2_int8_test.py | 52 +++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tflitehub/mobilenet_v2_int8_test.py b/tflitehub/mobilenet_v2_int8_test.py index a3fcefee..93c81fba 100644 --- a/tflitehub/mobilenet_v2_int8_test.py +++ b/tflitehub/mobilenet_v2_int8_test.py @@ -5,8 +5,59 @@ import numpy import test_util +import os +import random +import re +from typing import Any, Callable, Mapping, Sequence, Set, Tuple, Union +import numpy as np + model_path = "https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v2_1.0_int8/mobilenet_v2_1.00_224_int8.tflite" +def to_mlir_type(dtype: np.dtype) -> str: + """Returns a string that denotes the type 'dtype' in MLIR style.""" + if not isinstance(dtype, np.dtype): + # Handle np.int8 _not_ being a dtype. + dtype = np.dtype(dtype) + bits = dtype.itemsize * 8 + if np.issubdtype(dtype, np.integer): + return f"i{bits}" + elif np.issubdtype(dtype, np.floating): + return f"f{bits}" + else: + raise TypeError(f"Expected integer or floating type, but got {dtype}") + + +def get_shape_and_dtype(array: np.ndarray, + allow_non_mlir_dtype: bool = False) -> str: + shape_dtype = [str(dim) for dim in list(array.shape)] + if np.issubdtype(array.dtype, np.number): + shape_dtype.append(to_mlir_type(array.dtype)) + elif np.issubdtype(array.dtype, bool): + shape_dtype.append("i8") + elif allow_non_mlir_dtype: + shape_dtype.append(f"") + else: + raise TypeError(f"Expected integer or floating type, but got {array.dtype}") + return "x".join(shape_dtype) + + +def save_input_values(inputs: Sequence[np.ndarray], + file_path: str = None) -> str: + result = [] + for array in inputs: + shape_dtype = get_shape_and_dtype(array) + if np.issubdtype(array.dtype, bool): + values = 1 if array else 0 + else: + values = " ".join([str(x) for x in array.flatten()]) + result.append(f"--function_input={shape_dtype}={values}") + result = "\n".join(result) + print("Saving IREE input values to: %s", file_path) + with open(file_path, "w") as f: + f.write(result) + f.write("\n") + return result + class MobilenetV2Int8Test(test_util.TFLiteModelTest): def __init__(self, *args, **kwargs): super(MobilenetV2Int8Test, self).__init__(model_path, *args, **kwargs) @@ -26,6 +77,7 @@ def generate_inputs(self, input_details): inputs = imagenet_test_data.generate_input(self.workdir, input_details) # Normalize inputs to [-1, 1]. inputs = (inputs.astype('float32') / 127.5) - 1 + save_input_values(inputs, '/tmp/iree-samples/inputs.txt') return [inputs] def test_compile_tflite(self):