@@ -2150,7 +2150,7 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=(-6, 6), int_quan
21502150 with tf .Graph ().as_default ():
21512151 in_data = array_ops .placeholder (shape = data .shape , dtype = data .dtype , name = "in" )
21522152 out = math_op (in_data )
2153- compare_tflite_with_tvm (data , ["in:0" ], [in_data ], [out ])
2153+ compare_tflite_with_tvm (data , ["in:0" ], [in_data ], [out ], experimental_new_converter = True )
21542154
21552155
21562156def _unary_elewise_create_model (math_op , data , offset = 0 , int_quant_dtype = tf .int8 ):
@@ -2400,6 +2400,16 @@ def _test_elu(data, quantized, int_quant_dtype=tf.int8):
24002400 return _test_unary_elemwise (nn_ops .elu , data , quantized , int_quant_dtype = int_quant_dtype )
24012401
24022402
2403+ #######################################################################
2404+ # Gelu
2405+ # ---
2406+
2407+
2408+ def _test_gelu (data , quantized , int_quant_dtype = tf .int8 ):
2409+ """One iteration of elu"""
2410+ return _test_unary_elemwise (nn_ops .gelu , data , quantized , int_quant_dtype = int_quant_dtype )
2411+
2412+
24032413def _test_forward_unary_elemwise (test_op , int_quant_dtype = None , quantized = True , negative = True ):
24042414 # input data
24052415 in_data , inq_data = [], []
@@ -2439,15 +2449,16 @@ def test_all_unary_elemwise():
24392449 _test_forward_unary_elemwise (_test_sin )
24402450 _test_forward_unary_elemwise (_test_neg )
24412451 _test_forward_unary_elemwise (_test_sqrt , negative = False )
2452+ _test_forward_unary_elemwise (_test_gelu , quantized = False )
24422453 # tensorflow version upgrade support
2443- if tf .__version__ < LooseVersion ("2.6.1" ):
2454+ if package_version . parse ( tf .VERSION ) < package_version . parse ("2.6.1" ):
24442455 _test_forward_unary_elemwise (_test_rsqrt , negative = False , int_quant_dtype = tf .uint8 )
24452456 else :
24462457 _test_forward_unary_elemwise (_test_rsqrt , negative = False , int_quant_dtype = tf .int8 )
24472458 # ceil and cos come with TFLite 1.14.0.post1 fbs schema
24482459 if package_version .parse (tf .VERSION ) >= package_version .parse ("1.14.0" ):
24492460 _test_forward_unary_elemwise (_test_ceil )
2450- if tf .__version__ < LooseVersion ("2.6.1" ):
2461+ if package_version . parse ( tf .VERSION ) < package_version . parse ("2.6.1" ):
24512462 _test_forward_unary_elemwise (_test_cos , quantized = False )
24522463 else :
24532464 _test_forward_unary_elemwise (_test_cos , int_quant_dtype = tf .int8 )
0 commit comments