Skip to content

Commit 4d77595

Browse files
author
Ashutosh Parkhi
committed
Made TFLite layer creation explicit
Change-Id: I7fbf6a5a2163c1fada49477f86d84f1bc09bd57c
1 parent f2a3484 commit 4d77595

File tree

3 files changed

+47
-170
lines changed

3 files changed

+47
-170
lines changed

python/tvm/relay/testing/tflite.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -35,44 +35,53 @@ def __init__(self, dtype):
3535
self.shape_dict = {}
3636
self.dtype_dict = {}
3737

38-
@tf.function
39-
def conv2d_single_function(self, ifm_tensor, args):
40-
"""Returns TFLite Conv2d layer"""
41-
assert len(args) == 6, "Conv2D needs (ifm_shape, kernel_shape, strides, padding, dilation)"
42-
_, kernel_shape, strides, padding, dilation, activation = args
43-
op = tf.nn.conv2d(
44-
ifm_tensor,
45-
filters=tf.constant(
46-
np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]),
47-
dtype=tf.float32,
48-
),
49-
strides=[1, strides[0], strides[1], 1],
50-
padding=padding,
51-
dilations=dilation,
52-
)
53-
if activation == "RELU":
54-
op = tf.nn.relu(op)
55-
elif activation == "NONE":
56-
pass
57-
else:
58-
assert False, "Unsupported activation {}".format(activation)
59-
return op
60-
61-
def create_tflite_model(self, op_type, *args):
62-
"""Returns TFLite serial graph, Relay module, Relay params based on op_type"""
63-
concrete_func = None
64-
input_shape = None
65-
if op_type == "conv2d_single":
66-
input_shape = args[0]
67-
ifm_tensor = tf.TensorSpec(input_shape, dtype=tf.float32, name="input")
68-
concrete_func = self.conv2d_single_function.get_concrete_function(ifm_tensor, args)
69-
else:
70-
assert False, "Unsupported op_type {}".format(op_type)
38+
def create_conv2d_single(self, kernel_shape, strides, padding, dilation, activation):
39+
@tf.function
40+
def conv2d_single_function(ifm_tensor):
41+
"""Returns TFLite Conv2d layer"""
42+
op = tf.nn.conv2d(
43+
ifm_tensor,
44+
filters=tf.constant(
45+
np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]),
46+
dtype=tf.float32,
47+
),
48+
strides=[1, strides[0], strides[1], 1],
49+
padding=padding,
50+
dilations=dilation,
51+
)
52+
if activation == "RELU":
53+
op = tf.nn.relu(op)
54+
elif activation == "NONE":
55+
pass
56+
else:
57+
assert False, "Unsupported activation {}".format(activation)
58+
return op
59+
60+
return conv2d_single_function
61+
62+
def create_tflite_model(self, tfl_function, shapes, ranges=None):
63+
"""Creates TFLite serial graph"""
64+
tensor_specs = []
65+
for i, shape in enumerate(shapes):
66+
input_name = "input_" + str(i)
67+
self.shape_dict.update({input_name: shape})
68+
self.dtype_dict.update({input_name: self.dtype})
69+
tensor_specs.append(tf.TensorSpec(shape, dtype=tf.float32, name=input_name))
70+
concrete_func = tfl_function.get_concrete_function(*tensor_specs)
71+
72+
if not ranges:
73+
ranges = [(0, 1) for _ in shapes]
7174

7275
def representative_dataset():
7376
for _ in range(100):
74-
data = np.random.rand(*tuple(input_shape))
75-
yield [data.astype(np.float32)]
77+
inputs = []
78+
for i, shape in enumerate(shapes):
79+
data = np.random.uniform(
80+
low=ranges[i][0], high=ranges[i][1], size=tuple(shape)
81+
).astype("float32")
82+
inputs.append(data)
83+
84+
yield inputs
7685

7786
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
7887
converter.optimizations = [tf.lite.Optimize.DEFAULT]
@@ -81,8 +90,6 @@ def representative_dataset():
8190
converter.inference_input_type = tf.int8
8291
converter.inference_output_type = tf.int8
8392
self.serial_model = converter.convert()
84-
self.shape_dict = {"input": input_shape}
85-
self.dtype_dict = {"input": self.dtype}
8693

8794
def convert_to_relay(self):
8895
"""Converts TFLite serialized graph into Relay"""

tests/python/contrib/test_cmsisnn/test_conv2d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,10 @@ def test_conv2d_int8_tflite(ifm_shape, kernel_shape, strides, dilation, padding,
317317
from tvm.relay.testing.tflite import TFLiteModel
318318

319319
tfl_model = TFLiteModel(dtype)
320-
tfl_model.create_tflite_model(
321-
"conv2d_single", ifm_shape, kernel_shape, strides, padding, dilation, activation
320+
conv2d_function = tfl_model.create_conv2d_single(
321+
kernel_shape, strides, padding, dilation, activation
322322
)
323+
tfl_model.create_tflite_model(conv2d_function, [ifm_shape])
323324
relay_mod, relay_params = tfl_model.convert_to_relay()
324325

325326
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, relay_params)

tests/python/contrib/test_cmsisnn/utils.py

Lines changed: 0 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -225,134 +225,3 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype):
225225
)
226226
if fused_activation_fn == "RELU":
227227
return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax)
228-
229-
230-
def generate_random_input_data(seed, shape, dtype):
231-
"""
232-
Generates randomized input numpy arrays based on shape and dtype
233-
"""
234-
random_state = np.random.RandomState(seed)
235-
if dtype == np.float32:
236-
return random_state.uniform(-1, 1, size).astype(dtype)
237-
else:
238-
low = np.iinfo(dtype).min
239-
high = np.iinfo(dtype).max + 1
240-
return random_state.randint(low, high, shape, dtype)
241-
242-
243-
def generate_ref_data_tflite(model):
244-
"""
245-
This method uses TFLite reference kernels to generate reference output.
246-
Random input generator is used to get the input data.
247-
It returns randomized inputs and reference outputs.
248-
"""
249-
import tensorflow as tf
250-
from distutils.version import LooseVersion
251-
252-
output_tolerance = None
253-
if tf.__version__ < LooseVersion("2.5.0"):
254-
output_tolerance = 1
255-
interpreter = tf.lite.Interpreter(model_content=model)
256-
else:
257-
from tensorflow.lite.python.interpreter import OpResolverType
258-
259-
output_tolerance = 0
260-
interpreter = tf.lite.Interpreter(
261-
model_content=model,
262-
experimental_op_resolver_type=OpResolverType.BUILTIN_REF,
263-
experimental_preserve_all_tensors=False,
264-
)
265-
266-
interpreter.allocate_tensors()
267-
input_details = interpreter.get_input_details()
268-
output_details = interpreter.get_output_details()
269-
270-
# Generate predictable randomized input
271-
seed = 0
272-
input_data = {}
273-
for input_detail in input_details:
274-
input_values = generate_random_input_data(
275-
seed, input_detail["shape"], input_detail["dtype"]
276-
)
277-
interpreter.set_tensor(input_detail["index"], input_values)
278-
input_data.update({input_detail["name"]: input_values})
279-
280-
interpreter.invoke()
281-
282-
# Obtain the expected output from interpreter
283-
expected_output_data = {}
284-
for output_detail in output_details:
285-
expected_output_data.update(
286-
{output_detail["name"]: interpreter.get_tensor(output_detail["index"])}
287-
)
288-
289-
return input_data, expected_output_data, output_tolerance
290-
291-
292-
def create_conv2d_tflite_model(ifm_shape, kernel_shape, strides, dilation, padding, activation):
293-
"""This method prepares TFlite graph with a single Conv2d layer"""
294-
import tensorflow as tf
295-
296-
class Model(tf.Module):
297-
@tf.function
298-
def tf_function(self, x):
299-
# Use tf.nn API to create the model
300-
tf_strides = [1, strides[0], strides[1], 1]
301-
op = tf.nn.conv2d(
302-
x,
303-
filters=tf.constant(
304-
np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]),
305-
dtype=tf.float32,
306-
),
307-
strides=tf_strides,
308-
padding=padding,
309-
dilations=dilation,
310-
)
311-
if activation:
312-
op = tf.nn.relu(op)
313-
return op
314-
315-
model = Model()
316-
concrete_func = model.tf_function.get_concrete_function(
317-
tf.TensorSpec(ifm_shape, dtype=tf.float32)
318-
)
319-
320-
def representative_dataset():
321-
for _ in range(100):
322-
data = np.random.rand(*tuple(ifm_shape))
323-
yield [data.astype(np.float32)]
324-
325-
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
326-
converter.optimizations = [tf.lite.Optimize.DEFAULT]
327-
converter.representative_dataset = representative_dataset
328-
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
329-
converter.inference_input_type = tf.int8
330-
converter.inference_output_type = tf.int8
331-
tflite_model = converter.convert()
332-
return tflite_model
333-
334-
335-
def create_conv2d_tflite_relay_models(
336-
ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype
337-
):
338-
"""
339-
This method creates a conv2d TFLite layer and prepared TFLite model from it.
340-
Converts that into the Relay module and params.
341-
Returns TFLite model, Relay module and params.
342-
"""
343-
pytest.importorskip("tflite")
344-
import tflite.Model
345-
346-
serialized_tflite_model = create_conv2d_tflite_model(
347-
ifm_shape, kernel_shape, strides, dilation, padding, activation
348-
)
349-
350-
tflite_model = tflite.Model.Model.GetRootAsModel(serialized_tflite_model, 0)
351-
352-
relay_module, params = relay.frontend.from_tflite(
353-
tflite_model,
354-
shape_dict={"input": ifm_shape},
355-
dtype_dict={"input": dtype},
356-
)
357-
358-
return serialized_tflite_model, relay_module, params

0 commit comments

Comments
 (0)