Skip to content

Commit 72e11ba

Browse files
authored
[CMSIS-NN] Moved TFLite model making to common area (#10939)
* [CMSIS-NN] Moved TFLite model making to common area
  Change-Id: Ic4dbc1919ff0b481c05daf7e57cf9b055c714c9c
* Fixed lint issues with tensorflow import
  Change-Id: I7a520beec9c244e9c790d3e82733c2fb476f7e5e
* Resolved merge conflict with main
  Change-Id: Iefe58dd321efae6eae26cd54a31c5923d0f1e32b
* Made TFLite layer creation explicit
  Change-Id: I7fbf6a5a2163c1fada49477f86d84f1bc09bd57c
* Lint fix: added a missing docstring
  Change-Id: If1fb8bb09c538c04e333ccab65a20cff247a504d
1 parent 9fd279b commit 72e11ba

File tree

3 files changed

+172
-139
lines changed

3 files changed

+172
-139
lines changed

python/tvm/relay/testing/tflite.py

Lines changed: 161 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,161 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Common utilities for creating TFLite models"""
18+
from distutils.version import LooseVersion
19+
import numpy as np
20+
import pytest
21+
import tvm
22+
23+
pytest.importorskip("tflite")
24+
pytest.importorskip("tensorflow")
25+
import tflite.Model # pylint: disable=wrong-import-position
26+
import tensorflow as tf # pylint: disable=wrong-import-position
27+
28+
29+
class TFLiteModel:
    """Creates TFLite Model and facilitates reference data generation.

    Typical flow: build a ``tf.function`` (for example with
    ``create_conv2d_single``), pass it to ``create_tflite_model`` to obtain the
    serialized TFLite graph, then call ``convert_to_relay`` and/or
    ``generate_reference_data``.
    """

    def __init__(self, dtype):
        self.serial_model = None  # This is what TFLite convert() provides
        self.dtype = dtype  # This is the dtype of graph inputs
        self.shape_dict = {}  # input name -> shape, filled by create_tflite_model
        self.dtype_dict = {}  # input name -> dtype, filled by create_tflite_model

    def create_conv2d_single(self, kernel_shape, strides, padding, dilation, activation):
        """Returns tf.function that creates TFLite Conv2d layer.

        Parameters
        ----------
        kernel_shape : sequence
            First two entries are used as the spatial dimensions of the filter.
        strides : sequence
            First two entries are placed in the middle of ``[1, s0, s1, 1]``.
        padding
            Forwarded unchanged to ``tf.nn.conv2d``.
        dilation
            Forwarded unchanged to ``tf.nn.conv2d`` as ``dilations``.
        activation : str
            ``"RELU"`` appends ``tf.nn.relu``; ``"NONE"`` leaves the convolution
            output untouched. Any other value fails an assertion when the
            returned function is traced.
        """

        @tf.function
        def conv2d_single_function(ifm_tensor):
            """Returns TFLite Conv2d layer"""
            # NOTE: the last two filter dimensions are hard-coded to 3 and 3,
            # and the weights are drawn from the global numpy RNG.
            op = tf.nn.conv2d(
                ifm_tensor,
                filters=tf.constant(
                    np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]),
                    dtype=tf.float32,
                ),
                strides=[1, strides[0], strides[1], 1],
                padding=padding,
                dilations=dilation,
            )
            if activation == "RELU":
                op = tf.nn.relu(op)
            elif activation == "NONE":
                pass
            else:
                assert False, "Unsupported activation {}".format(activation)
            return op

        return conv2d_single_function

    def create_tflite_model(self, tfl_function, shapes, ranges=None):
        """Creates TFLite serial graph.

        Parameters
        ----------
        tfl_function
            Object exposing ``get_concrete_function`` (e.g. a ``tf.function``).
        shapes : list
            One shape per graph input; inputs are named ``input_0``, ``input_1``, ...
        ranges : list of (low, high), optional
            Per-input bounds for the representative dataset. Defaults to
            ``(0, 1)`` for every input when omitted or empty.

        The serialized model is stored in ``self.serial_model``; ``shape_dict``
        and ``dtype_dict`` are updated with one entry per input.
        """
        tensor_specs = []
        for i, shape in enumerate(shapes):
            input_name = "input_" + str(i)
            self.shape_dict.update({input_name: shape})
            self.dtype_dict.update({input_name: self.dtype})
            tensor_specs.append(tf.TensorSpec(shape, dtype=tf.float32, name=input_name))
        concrete_func = tfl_function.get_concrete_function(*tensor_specs)

        if not ranges:
            ranges = [(0, 1) for _ in shapes]

        def representative_dataset():
            # Yields 100 samples, each a list with one float32 array per input.
            for _ in range(100):
                inputs = []
                for i, shape in enumerate(shapes):
                    data = np.random.uniform(
                        low=ranges[i][0], high=ranges[i][1], size=tuple(shape)
                    ).astype("float32")
                    inputs.append(data)

                yield inputs

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        self.serial_model = converter.convert()

    def convert_to_relay(self):
        """Converts TFLite serialized graph into Relay.

        Returns
        -------
        tuple
            ``(relay_module, relay_params)`` as returned by
            ``tvm.relay.frontend.from_tflite``.
        """
        assert self.serial_model is not None, "TFLite model is empty!"

        tflite_model = tflite.Model.Model.GetRootAsModel(self.serial_model, 0)
        relay_module, relay_params = tvm.relay.frontend.from_tflite(
            tflite_model, self.shape_dict, self.dtype_dict
        )
        return relay_module, relay_params

    def generate_randomized_input_data(self, seed, shape, dtype):
        """Generates randomized input numpy arrays based on shape and dtype.

        For ``np.float32`` the values are uniform in [-1, 1); for any other
        dtype the full integer range reported by ``np.iinfo(dtype)`` is used.
        The same ``seed`` always yields the same array.
        """
        random_state = np.random.RandomState(seed)
        if dtype == np.float32:
            # Fixed: this branch previously passed the undefined name `size`,
            # which raised NameError for every float32 request.
            return random_state.uniform(-1, 1, shape).astype(dtype)
        low = np.iinfo(dtype).min
        high = np.iinfo(dtype).max + 1  # randint's upper bound is exclusive
        return random_state.randint(low, high, shape, dtype)

    # pylint: disable=import-outside-toplevel
    def generate_reference_data(self):
        """
        This method uses TFLite reference kernels to generate reference output.
        It returns randomized inputs and reference outputs.

        Returns
        -------
        tuple
            ``(input_data, expected_output_data, output_tolerance)`` where the
            first two are dicts keyed by tensor name. ``output_tolerance`` is 1
            when TensorFlow is older than 2.5.0 (default interpreter) and 0
            otherwise (BUILTIN_REF op resolver).
        """
        assert self.serial_model is not None, "TFLite model was not created."

        output_tolerance = None
        if tf.__version__ < LooseVersion("2.5.0"):
            output_tolerance = 1
            interpreter = tf.lite.Interpreter(model_content=self.serial_model)
        else:
            output_tolerance = 0
            interpreter = tf.lite.Interpreter(
                model_content=self.serial_model,
                experimental_op_resolver_type=tf.lite.experimental.OpResolverType.BUILTIN_REF,
                experimental_preserve_all_tensors=False,
            )

        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        # Generate predictable randomized input
        seed = 0
        input_data = {}
        for input_detail in input_details:
            input_values = self.generate_randomized_input_data(
                seed, input_detail["shape"], input_detail["dtype"]
            )
            interpreter.set_tensor(input_detail["index"], input_values)
            input_data.update({input_detail["name"]: input_values})

        interpreter.invoke()

        # Obtain the expected output from interpreter
        expected_output_data = {}
        for output_detail in output_details:
            expected_output_data.update(
                {output_detail["name"]: interpreter.get_tensor(output_detail["index"])}
            )

        return input_data, expected_output_data, output_tolerance

tests/python/contrib/test_cmsisnn/test_conv2d.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -35,14 +35,12 @@
3535
from utils import (
3636
skip_if_no_reference_system,
3737
make_module,
38-
create_conv2d_tflite_relay_models,
3938
get_range_for_dtype_str,
4039
get_same_padding,
4140
get_conv2d_qnn_params,
4241
make_qnn_relu,
4342
assert_partitioned_function,
4443
assert_no_external_function,
45-
generate_ref_data_tflite,
4644
)
4745

4846

@@ -314,25 +312,30 @@ def test_conv2d_int8_tflite(ifm_shape, kernel_shape, strides, dilation, padding,
314312
interface_api = "c"
315313
use_unpacked_api = True
316314
test_runner = AOT_USMP_CORSTONE300_RUNNER
317-
318315
dtype = "int8"
319-
tflite_model, relay_mod, params = create_conv2d_tflite_relay_models(
320-
ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype
316+
317+
from tvm.relay.testing.tflite import TFLiteModel
318+
319+
tfl_model = TFLiteModel(dtype)
320+
conv2d_function = tfl_model.create_conv2d_single(
321+
kernel_shape, strides, padding, dilation, activation
321322
)
323+
tfl_model.create_tflite_model(conv2d_function, [ifm_shape])
324+
relay_mod, relay_params = tfl_model.convert_to_relay()
322325

323-
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, params)
326+
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, relay_params)
324327

325328
# validate pattern matching
326329
assert_partitioned_function(relay_mod, cmsisnn_mod)
327330

328331
# validate CMSIS-NN output against TFLite output
329-
input_map, output_map, output_tolerance = generate_ref_data_tflite(tflite_model)
332+
input_map, output_map, output_tolerance = tfl_model.generate_reference_data()
330333
compile_and_run(
331334
AOTTestModel(
332335
module=cmsisnn_mod,
333336
inputs=input_map,
334337
outputs=output_map,
335-
params=params,
338+
params=relay_params,
336339
output_tolerance=output_tolerance,
337340
),
338341
test_runner,

tests/python/contrib/test_cmsisnn/utils.py

Lines changed: 0 additions & 131 deletions
Original file line number | Diff line number | Diff line change
@@ -225,134 +225,3 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype):
225225
)
226226
if fused_activation_fn == "RELU":
227227
return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax)
228-
229-
230-
def generate_random_input_data(seed, shape, dtype):
    """
    Generates randomized input numpy arrays based on shape and dtype

    For ``np.float32`` the values are uniform in [-1, 1); for any other dtype
    the full integer range reported by ``np.iinfo(dtype)`` is used. The same
    ``seed`` always yields the same array.
    """
    random_state = np.random.RandomState(seed)
    if dtype == np.float32:
        # Fixed: this branch previously passed the undefined name `size`,
        # which raised NameError for every float32 request.
        return random_state.uniform(-1, 1, shape).astype(dtype)
    low = np.iinfo(dtype).min
    high = np.iinfo(dtype).max + 1  # randint's upper bound is exclusive
    return random_state.randint(low, high, shape, dtype)
241-
242-
243-
def generate_ref_data_tflite(model):
    """
    Runs a serialized TFLite model through the TFLite interpreter on seeded
    random inputs and collects the outputs as reference data.

    Returns a tuple ``(inputs, expected_outputs, tolerance)``; the first two are
    dicts keyed by tensor name. The tolerance is 1 when TensorFlow is older than
    2.5.0 (default interpreter) and 0 otherwise (BUILTIN_REF op resolver).
    """
    import tensorflow as tf
    from distutils.version import LooseVersion

    if tf.__version__ < LooseVersion("2.5.0"):
        tolerance = 1
        runner = tf.lite.Interpreter(model_content=model)
    else:
        from tensorflow.lite.python.interpreter import OpResolverType

        tolerance = 0
        runner = tf.lite.Interpreter(
            model_content=model,
            experimental_op_resolver_type=OpResolverType.BUILTIN_REF,
            experimental_preserve_all_tensors=False,
        )

    runner.allocate_tensors()
    in_specs = runner.get_input_details()
    out_specs = runner.get_output_details()

    # Every input is generated from the same fixed seed so runs are repeatable.
    fixed_seed = 0
    inputs = {}
    for spec in in_specs:
        values = generate_random_input_data(fixed_seed, spec["shape"], spec["dtype"])
        runner.set_tensor(spec["index"], values)
        inputs[spec["name"]] = values

    runner.invoke()

    # Read back what the interpreter produced for each output tensor.
    expected_outputs = {spec["name"]: runner.get_tensor(spec["index"]) for spec in out_specs}

    return inputs, expected_outputs, tolerance
290-
291-
292-
def create_conv2d_tflite_model(ifm_shape, kernel_shape, strides, dilation, padding, activation):
    """Builds a serialized, int8-quantized TFLite graph holding one Conv2d layer.

    The filter weights are random, with the last two filter dimensions fixed at
    3 and 3. A ReLU is appended whenever ``activation`` is truthy.
    """
    import tensorflow as tf

    class Model(tf.Module):
        @tf.function
        def tf_function(self, x):
            # Use tf.nn API to create the model
            weights = tf.constant(
                np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]),
                dtype=tf.float32,
            )
            result = tf.nn.conv2d(
                x,
                filters=weights,
                strides=[1, strides[0], strides[1], 1],
                padding=padding,
                dilations=dilation,
            )
            # NOTE(review): plain truthiness test - a non-empty string such as
            # "NONE" would also enable the ReLU; confirm what callers pass.
            return tf.nn.relu(result) if activation else result

    input_spec = tf.TensorSpec(ifm_shape, dtype=tf.float32)
    concrete_func = Model().tf_function.get_concrete_function(input_spec)

    def representative_dataset():
        # 100 calibration samples drawn uniformly from [0, 1).
        for _ in range(100):
            sample = np.random.rand(*tuple(ifm_shape))
            yield [sample.astype(np.float32)]

    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    return converter.convert()
333-
334-
335-
def create_conv2d_tflite_relay_models(
    ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype
):
    """
    Builds a single-Conv2d TFLite model and imports it into Relay.

    Returns the serialized TFLite model, the Relay module and the Relay params,
    in that order. The test is skipped when the ``tflite`` package is missing.
    """
    pytest.importorskip("tflite")
    import tflite.Model

    serialized = create_conv2d_tflite_model(
        ifm_shape, kernel_shape, strides, dilation, padding, activation
    )
    parsed_model = tflite.Model.Model.GetRootAsModel(serialized, 0)

    # The Relay importer is told the graph has one input, named "input".
    shapes = {"input": ifm_shape}
    dtypes = {"input": dtype}
    relay_module, params = relay.frontend.from_tflite(
        parsed_model,
        shape_dict=shapes,
        dtype_dict=dtypes,
    )

    return serialized, relay_module, params

0 commit comments

Comments
 (0)