2121from tensorflow .python .ops import init_ops
2222from tensorflow .core .framework import graph_pb2
2323
24- import nnvm . testing .tf
24+ import tvm . relay . testing .tf as tf_testing
2525
2626#######################################################################
2727# Generic run functions for TVM & tensorflow
@@ -784,9 +784,9 @@ def test_forward_pad():
784784def test_forward_inception_v3 ():
785785 '''test inception V3 model'''
786786 with tf .Graph ().as_default ():
787- graph_def = nnvm . testing . tf .get_workload ('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb' )
787+ graph_def = tf_testing .get_workload ('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb' )
788788 # Call the utility to import the graph definition into default graph.
789- graph_def = nnvm . testing . tf .ProcessGraphDefParam (graph_def )
789+ graph_def = tf_testing .ProcessGraphDefParam (graph_def )
790790
791791 data = np .random .uniform (size = (1 , 299 , 299 , 3 )).astype ('float32' )
792792
@@ -801,9 +801,9 @@ def test_forward_inception_v3():
801801def test_forward_inception_v1 ():
802802 '''test inception V1 model'''
803803 with tf .Graph ().as_default ():
804- graph_def = nnvm . testing . tf .get_workload ("InceptionV1/classify_image_graph_def-with_shapes.pb" )
804+ graph_def = tf_testing .get_workload ("InceptionV1/classify_image_graph_def-with_shapes.pb" )
805805 # Call the utility to import the graph definition into default graph.
806- graph_def = nnvm . testing . tf .ProcessGraphDefParam (graph_def )
806+ graph_def = tf_testing .ProcessGraphDefParam (graph_def )
807807
808808 # Build an image from random data.
809809 from PIL import Image
@@ -838,18 +838,18 @@ def test_forward_mobilenet():
838838 '''test mobilenet model'''
839839 # MobilenetV2
840840 with tf .Graph ().as_default ():
841- graph_def = nnvm . testing . tf .get_workload (
841+ graph_def = tf_testing .get_workload (
842842 "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz" ,
843843 "mobilenet_v2_1.4_224_frozen.pb" )
844844 # Call the utility to import the graph definition into default graph.
845- graph_def = nnvm . testing . tf .ProcessGraphDefParam (graph_def )
845+ graph_def = tf_testing .ProcessGraphDefParam (graph_def )
846846
847847 data = np .random .uniform (size = (1 , 224 , 224 , 3 )).astype ('float32' )
848848 out_node = 'MobilenetV2/Predictions/Reshape_1'
849849
850850 with tf .Session () as sess :
851851 # Add shapes to the graph.
852- graph_def = nnvm . testing . tf .AddShapesToGraphDef (sess , out_node )
852+ graph_def = tf_testing .AddShapesToGraphDef (sess , out_node )
853853 tf_output = run_tf_graph (sess , data , 'input:0' , out_node + ':0' )
854854 tvm_output = run_tvm_graph (graph_def , data , 'input' )
855855 tvm .testing .assert_allclose (np .squeeze (tvm_output [0 ]), np .squeeze (tf_output [0 ]), rtol = 1e-5 , atol = 1e-5 )
@@ -861,9 +861,9 @@ def test_forward_resnetv2():
861861 '''test resnet model'''
862862 if is_gpu_available ():
863863 with tf .Graph ().as_default ():
864- graph_def = nnvm . testing . tf .get_workload ("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb" )
864+ graph_def = tf_testing .get_workload ("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb" )
865865 # Call the utility to import the graph definition into default graph.
866- graph_def = nnvm . testing . tf .ProcessGraphDefParam (graph_def )
866+ graph_def = tf_testing .ProcessGraphDefParam (graph_def )
867867
868868 data = np .random .uniform (size = (128 , 224 , 224 , 3 )).astype ('float32' )
869869 out_node = 'ArgMax'
@@ -879,7 +879,7 @@ def test_forward_resnetv2():
879879dir (tf .contrib )
880880def test_forward_ptb ():
881881 '''test ptb model'''
882- config = nnvm . testing . tf .get_config ()
882+ config = tf_testing .get_config ()
883883 num_steps = config .num_steps
884884 num_hidden = config .hidden_size
885885 num_layers = config .num_layers
@@ -936,7 +936,7 @@ def _get_sample(data, state):
936936 "float32" )).asnumpy ()
937937 state_output = model .get_output (1 , tvm .nd .empty (out_state_shape ,
938938 "float32" )).asnumpy ()
939- sample = nnvm . testing . tf .pick_from_weight (tvm_output [0 ])
939+ sample = tf_testing .pick_from_weight (tvm_output [0 ])
940940
941941 return sample , state_output
942942
@@ -956,10 +956,10 @@ def _get_sample(data, state):
956956 return samples , state
957957
958958 with tf .Graph ().as_default ():
959- word_to_id , id_to_word , graph_def = nnvm . testing . tf .get_workload_ptb ()
959+ word_to_id , id_to_word , graph_def = tf_testing .get_workload_ptb ()
960960 vocab_size = len (word_to_id )
961961 # Call the utility to import the graph definition into default graph.
962- graph_def = nnvm . testing . tf .ProcessGraphDefParam (graph_def )
962+ graph_def = tf_testing .ProcessGraphDefParam (graph_def )
963963 sess = tf .Session ()
964964
965965 #TVM graph module creation
@@ -975,7 +975,7 @@ def _get_sample(data, state):
975975 for word in seed_for_sample ],
976976 in_state , params , cnt_sample )
977977 tvm_sample_str = _pretty_print (tvm_samples , False , id_to_word )
978- tf_samples , tf_state = nnvm . testing . tf .do_tf_sample (sess ,
978+ tf_samples , tf_state = tf_testing .do_tf_sample (sess ,
979979 [word_to_id [word ] for word in seed_for_sample ],
980980 in_state , cnt_sample )
981981 tf_sample_str = _pretty_print (tf_samples , False , id_to_word )
0 commit comments