@@ -2150,7 +2150,9 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=(-6, 6), int_quan
21502150 with tf .Graph ().as_default ():
21512151 in_data = array_ops .placeholder (shape = data .shape , dtype = data .dtype , name = "in" )
21522152 out = math_op (in_data )
2153- compare_tflite_with_tvm (data , ["in:0" ], [in_data ], [out ])
2153+ compare_tflite_with_tvm (
2154+ data , ["in:0" ], [in_data ], [out ], experimental_new_converter = True
2155+ )
21542156
21552157
21562158def _unary_elewise_create_model (math_op , data , offset = 0 , int_quant_dtype = tf .int8 ):
@@ -2400,6 +2402,16 @@ def _test_elu(data, quantized, int_quant_dtype=tf.int8):
24002402 return _test_unary_elemwise (nn_ops .elu , data , quantized , int_quant_dtype = int_quant_dtype )
24012403
24022404
2405+ #######################################################################
2406+ # Gelu
2407+ # ---
2408+
2409+
2410+ def _test_gelu (data , quantized , int_quant_dtype = tf .int8 ):
2411+ """One iteration of elu"""
2412+ return _test_unary_elemwise (nn_ops .gelu , data , quantized , int_quant_dtype = int_quant_dtype )
2413+
2414+
24032415def _test_forward_unary_elemwise (test_op , int_quant_dtype = None , quantized = True , negative = True ):
24042416 # input data
24052417 in_data , inq_data = [], []
@@ -2439,15 +2451,16 @@ def test_all_unary_elemwise():
24392451 _test_forward_unary_elemwise (_test_sin )
24402452 _test_forward_unary_elemwise (_test_neg )
24412453 _test_forward_unary_elemwise (_test_sqrt , negative = False )
2454+ _test_forward_unary_elemwise (_test_gelu , quantized = False )
24422455 # tensorflow version upgrade support
2443- if tf .__version__ < LooseVersion ("2.6.1" ):
2456+ if package_version . parse ( tf .VERSION ) < package_version . parse ("2.6.1" ):
24442457 _test_forward_unary_elemwise (_test_rsqrt , negative = False , int_quant_dtype = tf .uint8 )
24452458 else :
24462459 _test_forward_unary_elemwise (_test_rsqrt , negative = False , int_quant_dtype = tf .int8 )
24472460 # ceil and cos come with TFLite 1.14.0.post1 fbs schema
24482461 if package_version .parse (tf .VERSION ) >= package_version .parse ("1.14.0" ):
24492462 _test_forward_unary_elemwise (_test_ceil )
2450- if tf .__version__ < LooseVersion ("2.6.1" ):
2463+ if package_version . parse ( tf .VERSION ) < package_version . parse ("2.6.1" ):
24512464 _test_forward_unary_elemwise (_test_cos , quantized = False )
24522465 else :
24532466 _test_forward_unary_elemwise (_test_cos , int_quant_dtype = tf .int8 )
0 commit comments