|
33 | 33 | from ..backend.name_transforms import sanitize_name |
34 | 34 | from .common import ExprTable |
35 | 35 | from .common import infer_shape as _infer_shape |
36 | | -from .common import to_int_list, shape_of |
| 36 | +from .common import lstm_cell, to_int_list, shape_of |
37 | 37 | from .tflite_flexbuffer import FlexBufferDecoder |
38 | 38 |
|
39 | 39 | __all__ = ["from_tflite"] |
@@ -173,6 +173,7 @@ def __init__(self, model, subgraph, exp_tab): |
173 | 173 | "TRANSPOSE_CONV": self.convert_transpose_conv, |
174 | 174 | "TRANSPOSE": self.convert_transpose, |
175 | 175 | "UNPACK": self.convert_unpack, |
| 176 | + "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm, |
176 | 177 | "WHERE": self.convert_select, |
177 | 178 | "ZEROS_LIKE": self.convert_zeros_like, |
178 | 179 | } |
@@ -220,6 +221,41 @@ def check_unsupported_ops(self): |
220 | 221 | if len(raise_msg) > 0: |
221 | 222 | raise tvm.error.OpNotImplemented(raise_msg) |
222 | 223 |
|
def unbind(self, data, axis=1):
    """Split a TFLite tensor into per-index slices along ``axis``.

    This is a modified version compared to the one in common.py:
    that one takes a relay expression, while this one takes a TFLite
    TensorWrapper and resolves its expression through ``self.get_expr``.
    It also splits along axis 1 by default instead of axis 0.

    Parameters
    ----------
    data : tvm.relay.frontend.tflite.TensorWrapper
        Input tensor.
    axis : int
        Axis along which the tensor is split.

    Returns
    -------
    result : relay.expr.TupleWrapper
        One expression per index along ``axis``; each has the input
        shape with ``axis`` removed.

    Raises
    ------
    AttributeError
        If ``axis`` is greater than or equal to the rank of ``data``.
    """
    shape = to_int_list(self.get_tensor_shape(data))
    if axis >= len(shape):
        msg = "Please check input dim, it shouldn't be greater than or equal to rank."
        raise AttributeError(msg)

    selections = shape[axis]
    # Split directly along the requested axis. Reshaping the tensor so that
    # `axis` becomes the leading dim only reinterprets the buffer layout; it
    # is not a transpose, so it mixes up elements as soon as any dimension
    # in front of `axis` (e.g. the batch) is larger than 1.
    res_split = _op.split(self.get_expr(data.tensor_idx), selections, axis)
    ret = [_op.squeeze(res_split[i], axis=[axis]) for i in range(selections)]
    return _expr.TupleWrapper(_expr.Tuple(ret), selections)

| 258 | + |
223 | 259 | def convert_op_to_relay(self): |
224 | 260 | """Convert TFLite ops to relay ops""" |
225 | 261 | for op_idx in range(self.subgraph.OperatorsLength()): |
@@ -2715,6 +2751,148 @@ def convert_unpack(self, op): |
2715 | 2751 |
|
2716 | 2752 | return squeezed |
2717 | 2753 |
|
def convert_unidirectional_sequence_lstm(self, op):
    """Convert a TFLite UNIDIRECTIONAL_SEQUENCE_LSTM operator to relay.

    Quantized operators are rejected. The operator must carry 24 input
    slots and a single output. The input is unrolled along axis 1 and
    fed step by step to ``lstm_cell``; the per-step outputs are stacked
    back along axis 1.

    Input slots read here:
      0      : input sequence
      1-4    : input-to-{input, forget, cell, output} weights
      5-8    : recurrent-to-{input, forget, cell, output} weights
      12-15  : {input, forget, cell, output} gate biases
      18, 19 : output (hidden) state and cell state
    Slots 9, 10, 11, 16, 17, 20, 21, 22 and 23 are not occupied (their
    locations are -1 in the flatbuffer) and are ignored.

    NOTE(review): no operator options are read in this converter, so
    ``lstm_cell`` runs with its default activations and the sequence is
    always unrolled along axis 1 - confirm this matches the models in use.
    """
    if self.is_quantized(op):
        raise tvm.error.OpNotImplemented(
            "TFlite quantized UNIDIRECTIONALSEQUENCELSTM operator is not supported yet."
        )

    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 24, "input tensors length should be == 24"

    sequence_tensor = input_tensors[0]
    # Gate order everywhere below is (input, forget, cell, output).
    input_weight_tensors = input_tensors[1:5]
    recurrent_weight_tensors = input_tensors[5:9]
    bias_tensors = input_tensors[12:16]
    hidden_state_tensor = input_tensors[18]
    cell_state_tensor = input_tensors[19]

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"

    def _zero_state(state_tensor):
        # The state is not read from the model; a zero tensor with the
        # state's shape and dtype is used as the initial value.
        state_shape = tuple(self.get_tensor_shape(state_tensor))
        state_dtype = self.get_tensor_type_str(state_tensor.tensor.Type())
        return _op.split(_op.zeros(state_shape, dtype=state_dtype), 1)[0]

    def _constant(tensor):
        # Embed the tensor's stored values as a relay constant.
        values = self.get_tensor_value(tensor)
        return _op.split(_op.const(values.tolist()), 1)[0]

    steps = self.unbind(sequence_tensor, axis=1)

    hidden_state = _zero_state(hidden_state_tensor)
    cell_state = _zero_state(cell_state_tensor)

    # w_inp: the four input weight matrices concatenated along axis 0.
    w_inp = _op.concatenate(
        [_op.squeeze(_constant(tensor)) for tensor in input_weight_tensors],
        axis=0,
    )

    # w_hid: the four recurrent weight matrices concatenated along axis 0.
    w_hid = _op.concatenate(
        [_constant(tensor) for tensor in recurrent_weight_tensors],
        axis=0,
    )

    # b_inp: the four gate biases concatenated along axis 0.
    b_inp = _op.concatenate(
        [_constant(tensor) for tensor in bias_tensors],
        axis=0,
    )

    # b_hid: an all-zero bias with the same shape as b_inp, typed after
    # the input gate bias tensor.
    bias_dtype = self.get_tensor_type_str(bias_tensors[0].tensor.Type())
    zero_bias = np.zeros(_infer_shape(b_inp), dtype=bias_dtype)
    b_hid = _op.split(_op.const(zero_bias, dtype=bias_dtype), 1)[0]

    outputs, _, _ = lstm_cell(
        input_seqs=steps,
        hidden_state=hidden_state,
        cell_state=cell_state,
        w_inp=w_inp,
        w_hid=w_hid,
        b_inp=b_inp,
        b_hid=b_hid,
    )

    return _op.stack(outputs, axis=1)

| 2895 | + |
2718 | 2896 | def convert_batch_to_space_nd(self, op): |
2719 | 2897 | """batch_to_space_nd implementation.""" |
2720 | 2898 |
|
|
0 commit comments