
Commit 7766ab2

Author: Sebastian Boblest

Add unidirectional sequence lstm (#11183)
* UnidirectionalLSTM added
* fixed missing import
* fixed pylint warnings
* black formatted tflite.py
* corrections according to reviewer comments
* fixed black formatting
* just to trigger the CI again
* assertion now tests that there are exactly 24 input tensors.
* black formatted tflite.py
* added explanatory comment regarding unused imports
* removed unused import
* nothing
* nothing
* added some details in a comment about the differences in unbind compared to the version in common.py
* improved comment on unbind
* fix of black issue
1 parent 2a2d910 commit 7766ab2

2 files changed: 213 additions, 5 deletions


python/tvm/relay/frontend/tflite.py

Lines changed: 179 additions & 1 deletion
@@ -33,7 +33,7 @@
 from ..backend.name_transforms import sanitize_name
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
-from .common import to_int_list, shape_of
+from .common import lstm_cell, to_int_list, shape_of
 from .tflite_flexbuffer import FlexBufferDecoder

 __all__ = ["from_tflite"]
@@ -173,6 +173,7 @@ def __init__(self, model, subgraph, exp_tab):
             "TRANSPOSE_CONV": self.convert_transpose_conv,
             "TRANSPOSE": self.convert_transpose,
             "UNPACK": self.convert_unpack,
+            "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
             "WHERE": self.convert_select,
             "ZEROS_LIKE": self.convert_zeros_like,
         }
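The single dict entry above is all the wiring the frontend needs: the generic conversion loop looks the op name up in convert_map and calls the bound converter, and check_unsupported_ops reports names missing from the map. A minimal sketch of that dispatch pattern (the MiniConverter class is hypothetical, not the real OperatorConverter):

# Simplified illustration of the convert_map dispatch pattern: registering the
# name is enough for the generic conversion loop to find and call the converter.
class MiniConverter:
    def __init__(self):
        self.convert_map = {
            "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm,
        }

    def convert_unidirectional_sequence_lstm(self, op):
        return "relay expr for {}".format(op)

    def convert(self, op_code_str, op):
        if op_code_str not in self.convert_map:
            raise NotImplementedError("{} is not supported".format(op_code_str))
        return self.convert_map[op_code_str](op)

print(MiniConverter().convert("UNIDIRECTIONAL_SEQUENCE_LSTM", "op#0"))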
@@ -220,6 +221,41 @@ def check_unsupported_ops(self):
         if len(raise_msg) > 0:
             raise tvm.error.OpNotImplemented(raise_msg)

+    def unbind(self, data, axis=1):
+        """
+        This is a modified version of the helper in common.py.
+        The ONNX version takes a relay.Expr.Call, whereas this TFLite
+        version takes a TensorWrapper. It also splits along axis 1
+        by default, rather than axis 0 as the ONNX version does.
+
+        Parameters
+        ----------
+        data : tvm.relay.frontend.tflite.TensorWrapper
+            Input tensor
+        axis : int
+            Axis along which the tensor is split.
+
+        Returns
+        -------
+        result : List[relay.Expr]
+            The sequence of computed tensors
+        """
+        shape = to_int_list(self.get_tensor_shape(data))
+        if axis >= len(shape):
+            msg = "Please check input dim, it shouldn't be greater than or equal to rank."
+            raise AttributeError(msg)
+
+        selections = shape[axis]
+        shape.pop(axis)
+        timestep = 0  # Reshape to make the time step the first dim
+        shape.insert(timestep, selections)
+        res_split = _op.split(
+            _op.reshape(self.get_expr(data.tensor_idx), tuple(shape)), selections, timestep
+        )
+        ret = []
+        for i in range(selections):
+            ret.append(_op.squeeze(res_split[i], axis=[timestep]))
+        return _expr.TupleWrapper(_expr.Tuple(ret), selections)
+
     def convert_op_to_relay(self):
         """Convert TFLite ops to relay ops"""
         for op_idx in range(self.subgraph.OperatorsLength()):
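For intuition, here is a NumPy mirror of the reshape/split/squeeze sequence the new unbind helper performs, using the (1, 4, 3) batch/time/feature shape of the test data added below (a sketch of the semantics only, not frontend code):

# NumPy mirror of unbind(data, axis=1): turn a (batch, time, feature) tensor
# into `time` separate (batch, feature) arrays, time step becoming the leading dim.
import numpy as np

x = np.arange(1 * 4 * 3, dtype="float32").reshape(1, 4, 3)  # (batch, time, feature)
shape = list(x.shape)
selections = shape.pop(1)      # number of time steps along axis 1
shape.insert(0, selections)    # reshape so the time step is the first dim
steps = [np.squeeze(s, axis=0) for s in np.split(x.reshape(shape), selections, axis=0)]
assert len(steps) == 4 and steps[0].shape == (1, 3)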
@@ -2715,6 +2751,148 @@ def convert_unpack(self, op):

         return squeezed

+    def convert_unidirectional_sequence_lstm(self, op):
+        """Long Short-Term Memory (LSTM) implementation for TFLite."""
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "TFLite quantized UNIDIRECTIONAL_SEQUENCE_LSTM operator is not supported yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 24, "input tensors length should be == 24"
+
+        # Extract input tensor from saved model
+        input_tensor = input_tensors[0]
+
+        # Extract tensors from input tensors from saved model
+        # Input weights
+        input_input_weights = input_tensors[1]
+        input_forget_weights = input_tensors[2]
+        input_cell_weights = input_tensors[3]
+        input_output_weights = input_tensors[4]
+        # Recurrent weights
+        recurrent_input_weights = input_tensors[5]
+        recurrent_forget_weights = input_tensors[6]
+        recurrent_cell_weights = input_tensors[7]
+        recurrent_output_weights = input_tensors[8]
+        # Inputs 9, 10, 11, 16, 17, 20, 21, 22, 23 are not occupied;
+        # their locations are -1 in the flatbuffer
+        # Bias weights
+        input_gate_bias = input_tensors[12]
+        forget_gate_bias = input_tensors[13]
+        cell_gate_bias = input_tensors[14]
+        output_gate_bias = input_tensors[15]
+
+        # State input
+        output_state_in = input_tensors[18]
+        cell_state_in = input_tensors[19]
+
+        # Extract output tensor from saved model
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        X_steps = self.unbind(input_tensor, axis=1)
+        weights_dict = {}
+
+        # hidden_state_weights is equivalent to output_state_in in the tflite model
+        out_state_in_shape = tuple(self.get_tensor_shape(output_state_in))
+        out_state_in_dtype = self.get_tensor_type_str(output_state_in.tensor.Type())
+        out_state_in_expr = _op.zeros(out_state_in_shape, dtype=out_state_in_dtype)
+        weights_dict["hidden_state"] = _op.split(out_state_in_expr, 1)[0]
+
+        # cell_state_weights is equivalent to cell_state_in in the tflite model
+        cell_state_in_shape = tuple(self.get_tensor_shape(cell_state_in))
+        cell_state_in_dtype = self.get_tensor_type_str(cell_state_in.tensor.Type())
+        cell_state_in_expr = _op.zeros(cell_state_in_shape, dtype=cell_state_in_dtype)
+        weights_dict["cell_state"] = _op.split(cell_state_in_expr, 1)[0]
+
+        # Process weight matrix of input: w_inp
+        # Concatenation of [input_input_weights, input_forget_weights,
+        # input_cell_weights, input_output_weights]
+        input_input_weights_default_values = self.get_tensor_value(input_input_weights)
+        input_input_weights_op = _op.split(
+            _op.const(input_input_weights_default_values.tolist()), 1
+        )
+        input_output_weights_default_values = self.get_tensor_value(input_output_weights)
+        input_output_weights_op = _op.split(
+            _op.const(input_output_weights_default_values.tolist()), 1
+        )
+        input_forget_weights_default_values = self.get_tensor_value(input_forget_weights)
+        input_forget_weights_op = _op.split(
+            _op.const(input_forget_weights_default_values.tolist()), 1
+        )
+        input_cell_weights_default_values = self.get_tensor_value(input_cell_weights)
+        input_cell_weights_op = _op.split(_op.const(input_cell_weights_default_values.tolist()), 1)
+        weights_dict["w_inp"] = _op.concatenate(
+            [
+                _op.squeeze(input_input_weights_op[0]),
+                _op.squeeze(input_forget_weights_op[0]),
+                _op.squeeze(input_cell_weights_op[0]),
+                _op.squeeze(input_output_weights_op[0]),
+            ],
+            axis=0,
+        )
+
+        # Process weight matrix of hidden state:
+        # w_hid to support the lstm_cell function. Not used in tflite
+        recurrent_input_weights_values = self.get_tensor_value(recurrent_input_weights)
+        recurrent_input_weights_op = _op.split(
+            _op.const(recurrent_input_weights_values.tolist()), 1
+        )
+        recurrent_output_weights_values = self.get_tensor_value(recurrent_output_weights)
+        recurrent_output_weights_op = _op.split(
+            _op.const(recurrent_output_weights_values.tolist()), 1
+        )
+        recurrent_forget_weights_values = self.get_tensor_value(recurrent_forget_weights)
+        recurrent_forget_weights_op = _op.split(
+            _op.const(recurrent_forget_weights_values.tolist()), 1
+        )
+        recurrent_cell_weights_values = self.get_tensor_value(recurrent_cell_weights)
+        recurrent_cell_weights_op = _op.split(_op.const(recurrent_cell_weights_values.tolist()), 1)
+        weights_dict["w_hid"] = _op.concatenate(
+            [
+                recurrent_input_weights_op[0],
+                recurrent_forget_weights_op[0],
+                recurrent_cell_weights_op[0],
+                recurrent_output_weights_op[0],
+            ],
+            axis=0,
+        )
+
+        # Process weight matrix of bias: b_inp
+        input_gate_bias_values = self.get_tensor_value(input_gate_bias)
+        input_gate_bias_op = _op.split(_op.const(input_gate_bias_values.tolist()), 1)
+        output_gate_bias_values = self.get_tensor_value(output_gate_bias)
+        output_gate_bias_op = _op.split(_op.const(output_gate_bias_values.tolist()), 1)
+        forget_gate_bias_values = self.get_tensor_value(forget_gate_bias)
+        forget_gate_bias_op = _op.split(_op.const(forget_gate_bias_values.tolist()), 1)
+        cell_gate_bias_values = self.get_tensor_value(cell_gate_bias)
+        cell_gate_bias_op = _op.split(_op.const(cell_gate_bias_values.tolist()), 1)
+        weights_dict["b_inp"] = _op.concatenate(
+            [
+                input_gate_bias_op[0],
+                forget_gate_bias_op[0],
+                cell_gate_bias_op[0],
+                output_gate_bias_op[0],
+            ],
+            axis=0,
+        )
+
+        # Process weight matrix of hidden bias:
+        # b_hid (with the same shape as b_inp)
+        gate_bias_dtype = self.get_tensor_type_str(input_gate_bias.tensor.Type())
+        weights_dict["b_hid"] = _op.split(
+            _op.const(
+                np.zeros(_infer_shape(weights_dict["b_inp"]), dtype=gate_bias_dtype),
+                dtype=gate_bias_dtype,
+            ),
+            1,
+        )[0]
+
+        outputs, _, _ = lstm_cell(input_seqs=X_steps, **weights_dict)
+
+        output = _op.stack(outputs, axis=1)
+        return output
+
     def convert_batch_to_space_nd(self, op):
         """batch_to_space_nd implementation."""
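The converter stacks the per-gate TFLite tensors into single w_inp / w_hid / b_inp matrices in the order input, forget, cell, output before handing them to lstm_cell from common.py, which splits them back into the four gates. As a reference for that layout, a NumPy sketch of one step under the standard LSTM equations (the all-zero hidden bias mirrors the b_hid handling above; the weight values here are random placeholders):

# One LSTM time step with gates stacked as [input, forget, cell, output],
# matching the concatenation order used for w_inp, w_hid and b_inp above.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, w_inp, w_hid, b_inp, b_hid):
    # w_inp: (4*hidden, in_dim), w_hid: (4*hidden, hidden), biases: (4*hidden,)
    gates = w_inp @ x_t + w_hid @ h_prev + b_inp + b_hid
    i, f, c, o = np.split(gates, 4)
    c_next = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(c)
    h_next = sigmoid(o) * np.tanh(c_next)
    return h_next, c_next

hidden, in_dim, steps = 2, 3, 4
rng = np.random.default_rng(0)
w_inp = rng.standard_normal((4 * hidden, in_dim)).astype("float32")
w_hid = rng.standard_normal((4 * hidden, hidden)).astype("float32")
b_inp = rng.standard_normal(4 * hidden).astype("float32")
b_hid = np.zeros(4 * hidden, dtype="float32")  # zero hidden bias, as in the converter
h = c = np.zeros(hidden, dtype="float32")
outputs = []
for x_t in rng.standard_normal((steps, in_dim)).astype("float32"):
    h, c = lstm_step(x_t, h, c, w_inp, w_hid, b_inp, b_hid)
    outputs.append(h)
print(np.stack(outputs).shape)  # (4, 2): one hidden state per time step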

tests/python/frontend/tflite/test_forward.py

Lines changed: 34 additions & 4 deletions
@@ -1867,7 +1867,7 @@ def tf_function(self, x):
             model,
             export_dir,
             signatures=model.tf_function.get_concrete_function(
-                tf.TensorSpec(data.shape, tf.float32, name="input"),
+                tf.TensorSpec(data.shape, tf.float32, name="input")
             ),
         )

@@ -3759,8 +3759,7 @@ def test_forward_prelu():
         np.full((32, 3), 0.2, dtype="float32"),
     )
     _test_prelu(
-        np.random.uniform(-5, 5, size=(32, 3)).astype("float32"),
-        np.full((3), 0.2, dtype="float32"),
+        np.random.uniform(-5, 5, size=(32, 3)).astype("float32"), np.full((3), 0.2, dtype="float32")
     )

@@ -4693,6 +4692,36 @@ def representative_dataset():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


+#######################################################################
+# Unidirectional Sequence LSTM
+# ----------------------------
+def test_forward_unidirectional_sequence_lstm():
+    """Test the UnidirectionalSequenceLSTM TFLite op"""
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
+        tflite_model_file = download_testdata(
+            "https://github.com/SebastianBoblestETAS/nn_models/blob/ce49c5de64889493161ca4194a20e0fd5eb707e6/lstm_1_in_3_out_2_ts_4.tflite?raw=true",
+            "lstm_1_in_3_out_2_ts_4.tflite",
+        )
+        with open(tflite_model_file, "rb") as f:
+            tflite_model_buf = f.read()
+
+        data = np.array(
+            [
+                [
+                    [0.5488135, 0.71518934, 0.60276335],
+                    [0.5448832, 0.4236548, 0.6458941],
+                    [0.4375872, 0.891773, 0.96366274],
+                    [0.3834415, 0.79172504, 0.5288949],
+                ]
+            ],
+            dtype="float32",
+        )
+
+        tflite_output = run_tflite_graph(tflite_model_buf, data)
+        tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input_1:0")
+        tvm.testing.assert_allclose(tflite_output, tvm_output)
+
+
 #######################################################################
 # Quantized SSD Mobilenet
 # -----------------------
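run_tflite_graph and run_tvm_graph are existing helpers in test_forward.py; outside the test harness, the TVM side of the comparison reduces to roughly the following sketch (assuming the tflite flatbuffer package is installed; the input name and (1, 4, 3) shape come from the test above, and tflite_model_buf / data are as defined there):

# Hedged sketch of importing and running the downloaded model directly with TVM,
# without the run_tvm_graph helper. Names and shapes taken from the test above.
import tflite
import tvm
from tvm import relay
from tvm.contrib import graph_executor

tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"serving_default_input_1:0": (1, 4, 3)},
    dtype_dict={"serving_default_input_1:0": "float32"},
)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

dev = tvm.cpu(0)
module = graph_executor.GraphModule(lib["default"](dev))
module.set_input("serving_default_input_1:0", data)
module.run()
tvm_out = module.get_output(0).numpy()  # shape (1, 4, 2): batch, time steps, hidden units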
@@ -4930,10 +4959,11 @@ def test_prevent_tensorflow_dynamic_range():
     test_forward_leaky_relu()
     test_forward_relu_n1_to_1()
     test_forward_log_softmax()
-    test_forward_prelu()
     test_forward_fully_connected()
     test_forward_l2_normalization()
     test_forward_local_response_normalization()
+    test_forward_prelu()
+    test_forward_unidirectional_sequence_lstm()

     # Elemwise
     test_all_elemwise()
