1212from tensorflow .python .framework import constant_op
1313from tensorflow .python .framework import graph_util
1414from tensorflow .python .ops import nn_ops
15+ from tensorflow .python .ops import nn
1516from tensorflow .python .ops import array_ops
1617from tensorflow .python .ops import gen_array_ops
1718from tensorflow .python .ops import math_ops
@@ -878,7 +879,6 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
878879 sess ,
879880 sess .graph .as_graph_def (add_shapes = True ),
880881 ['lrn' ],)
881-
882882 tf_output = run_tf_graph (sess , inp_array , 'lrn0_data:0' , 'lrn:0' )
883883 tvm_output = run_tvm_graph (graph_def ,
884884 inp_array ,
@@ -889,6 +889,42 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
889889def test_forward_lrn ():
890890 _test_lrn ((1 , 3 , 20 , 20 ), 3 , 1 , 1.0 , 1.0 , 0.5 )
891891
892+ #######################################################################
893+ # l2_normalize
894+ # ------------
895+ def _test_l2_normalize (ishape , eps , axis ):
896+ """ testing l2 normalize (uses max, sum, square, sqrt frontend operators)"""
897+
898+ inp_array = np .random .uniform (size = ishape ).astype (np .float32 )
899+ inp_array .fill (1 )
900+
901+ with tf .Graph ().as_default ():
902+ in1 = tf .placeholder (shape = inp_array .shape , dtype = inp_array .dtype , name = "Placeholder" )
903+ nn .l2_normalize (in1 ,
904+ axis = axis ,
905+ epsilon = eps ,
906+ name = None ,
907+ dim = None )
908+
909+ with tf .Session () as sess :
910+ graph_def = tf .graph_util .convert_variables_to_constants (
911+ sess ,
912+ sess .graph .as_graph_def (add_shapes = True ),
913+ ['l2_normalize' ],
914+ )
915+ tf_output = run_tf_graph (sess , inp_array , 'Placeholder:0' , 'Placeholder:0' )
916+ tvm_output = run_tvm_graph (graph_def ,
917+ inp_array ,
918+ "Placeholder" ,
919+ tf_output .shape ,
920+ tf_output .dtype )
921+
922+ np .testing .assert_allclose (tf_output , tvm_output , atol = 1e-3 , rtol = 1e-3 )
923+ sess .close ()
924+ def test_forward_l2_normalize ():
925+ _test_l2_normalize ((1 , 3 , 20 , 20 ), 0.001 , (0 ,))
926+
927+ #######################################################################
892928# Main
893929# ----
894930if __name__ == '__main__' :
@@ -910,3 +946,4 @@ def test_forward_lrn():
910946 test_forward_gather ()
911947 test_forward_ptb ()
912948 test_forward_lrn ()
949+ test_forward_l2_normalize ()
0 commit comments