Skip to content

Commit b498016

Browse files
committed
[TFLite] Add support for GELU conversion
This commit adds support for converting a TFLite fp32 GELU operation to Relay. Also includes some neighbouring cleanup of version checks to silence warnings. Change-Id: Ic43b1525c4b80cf7f47281c52bb9a8f2643c4073
1 parent 278a6af commit b498016

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(self, model, subgraph, exp_tab):
109109
"GATHER_ND": self.convert_gather_nd,
110110
"GREATER_EQUAL": self.convert_greater_equal,
111111
"GREATER": self.convert_greater,
112+
"GELU": self.convert_gelu,
112113
"HARD_SWISH": self.convert_hard_swish,
113114
"L2_NORMALIZATION": self.convert_l2_normalization,
114115
"L2_POOL_2D": self.convert_l2_pool2d,
@@ -1287,6 +1288,24 @@ def convert_elu(self, op):
12871288

12881289
return out
12891290

1291+
def convert_gelu(self, op):
1292+
if self.is_quantized(op):
1293+
raise tvm.error.OpNotImplemented(
1294+
"The TFLite to Relay converter does not support quantized GELU operator yet."
1295+
)
1296+
1297+
input_tensors = self.get_input_tensors(op)
1298+
assert len(input_tensors) == 1, "input tensors length should be 1"
1299+
1300+
input_tensor = input_tensors[0]
1301+
in_expr = self.get_expr(input_tensor.tensor_idx)
1302+
in_type = self.get_tensor_type_str(input_tensor.tensor.Type())
1303+
1304+
return in_expr * (
1305+
_expr.const(0.5, dtype=in_type)
1306+
+ _op.erf(in_expr * _expr.const(0.5**0.5, dtype=in_type)) * _expr.const(0.5, dtype=in_type)
1307+
)
1308+
12901309
def convert_square(self, op):
12911310
"""Convert TFLite SQUARE"""
12921311
input_tensors = self.get_input_tensors(op)

tests/python/frontend/tflite/test_forward.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2150,7 +2150,7 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=(-6, 6), int_quan
21502150
with tf.Graph().as_default():
21512151
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in")
21522152
out = math_op(in_data)
2153-
compare_tflite_with_tvm(data, ["in:0"], [in_data], [out])
2153+
compare_tflite_with_tvm(data, ["in:0"], [in_data], [out], experimental_new_converter=True)
21542154

21552155

21562156
def _unary_elewise_create_model(math_op, data, offset=0, int_quant_dtype=tf.int8):
@@ -2400,6 +2400,16 @@ def _test_elu(data, quantized, int_quant_dtype=tf.int8):
24002400
return _test_unary_elemwise(nn_ops.elu, data, quantized, int_quant_dtype=int_quant_dtype)
24012401

24022402

2403+
#######################################################################
2404+
# Gelu
2405+
# ---
2406+
2407+
2408+
def _test_gelu(data, quantized, int_quant_dtype=tf.int8):
2409+
"""One iteration of elu"""
2410+
return _test_unary_elemwise(nn_ops.gelu, data, quantized, int_quant_dtype=int_quant_dtype)
2411+
2412+
24032413
def _test_forward_unary_elemwise(test_op, int_quant_dtype=None, quantized=True, negative=True):
24042414
# input data
24052415
in_data, inq_data = [], []
@@ -2439,15 +2449,16 @@ def test_all_unary_elemwise():
24392449
_test_forward_unary_elemwise(_test_sin)
24402450
_test_forward_unary_elemwise(_test_neg)
24412451
_test_forward_unary_elemwise(_test_sqrt, negative=False)
2452+
_test_forward_unary_elemwise(_test_gelu, quantized=False)
24422453
# tensorflow version upgrade support
2443-
if tf.__version__ < LooseVersion("2.6.1"):
2454+
if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"):
24442455
_test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.uint8)
24452456
else:
24462457
_test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.int8)
24472458
# ceil and cos come with TFLite 1.14.0.post1 fbs schema
24482459
if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
24492460
_test_forward_unary_elemwise(_test_ceil)
2450-
if tf.__version__ < LooseVersion("2.6.1"):
2461+
if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"):
24512462
_test_forward_unary_elemwise(_test_cos, quantized=False)
24522463
else:
24532464
_test_forward_unary_elemwise(_test_cos, int_quant_dtype=tf.int8)

0 commit comments

Comments
 (0)