Skip to content

Commit b488ade

Browse files
srkreddy1238AWS Neo
authored andcommitted
[RELAY][FRONTEND] Tensorflow frontend. (apache#2216)
* [RELAY][FRONTEND] Tensorflow frontend support. * * LSTM removed for a while. * * basic ops are good. * * nn wip * * wip * * python2.7 corrections. * * NN ops are good. * * e2e models working good * * all good except LSTM * * rebase, tutorials and CI trigger. * * CI errors. * * enable opt_level=3 * * Docstrings cleanup. testing.tf utils moved to relay from nnvm. * * tutorials update. * * LSTM work good now. * * Rebase * * CI error * * enable PTB. * * rebase. * * tutorials * Update python/tvm/relay/frontend/tensorflow.py Co-Authored-By: srkreddy1238 <[email protected]> * * review comments. * CI fix. * * review comments.
1 parent 55a48c5 commit b488ade

File tree

10 files changed

+2895
-25
lines changed

10 files changed

+2895
-25
lines changed

docs/frontend/tensorflow.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ instructions to generate protobuf from checkpoint.
2121

2222
### Add Shapes:
2323
While freezing of protobuf add additional option ```add_shapes=True``` to embed output shapes of each node into graph.
24-
You may use ```nnvm.testing.tf.AddShapesToGraphDef``` from nnvm for the same.
24+
You may use ```tvm.relay.testing.tf.AddShapesToGraphDef``` from nnvm for the same.
2525
Please refer to [tensorflow tutorial](https://github.com/dmlc/tvm/blob/master/tutorials/nnvm/from_tensorflow.py).
2626

2727
### Explicit Shape:

nnvm/tests/python/frontend/tensorflow/test_forward.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tensorflow.python.ops import init_ops
2222
from tensorflow.core.framework import graph_pb2
2323

24-
import nnvm.testing.tf
24+
import tvm.relay.testing.tf as tf_testing
2525

2626
#######################################################################
2727
# Generic run functions for TVM & tensorflow
@@ -784,9 +784,9 @@ def test_forward_pad():
784784
def test_forward_inception_v3():
785785
'''test inception V3 model'''
786786
with tf.Graph().as_default():
787-
graph_def = nnvm.testing.tf.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
787+
graph_def = tf_testing.get_workload('InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb')
788788
# Call the utility to import the graph definition into default graph.
789-
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
789+
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
790790

791791
data = np.random.uniform(size=(1, 299, 299, 3)).astype('float32')
792792

@@ -801,9 +801,9 @@ def test_forward_inception_v3():
801801
def test_forward_inception_v1():
802802
'''test inception V1 model'''
803803
with tf.Graph().as_default():
804-
graph_def = nnvm.testing.tf.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
804+
graph_def = tf_testing.get_workload("InceptionV1/classify_image_graph_def-with_shapes.pb")
805805
# Call the utility to import the graph definition into default graph.
806-
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
806+
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
807807

808808
# Build an image from random data.
809809
from PIL import Image
@@ -838,18 +838,18 @@ def test_forward_mobilenet():
838838
'''test mobilenet model'''
839839
# MobilenetV2
840840
with tf.Graph().as_default():
841-
graph_def = nnvm.testing.tf.get_workload(
841+
graph_def = tf_testing.get_workload(
842842
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz",
843843
"mobilenet_v2_1.4_224_frozen.pb")
844844
# Call the utility to import the graph definition into default graph.
845-
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
845+
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
846846

847847
data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
848848
out_node = 'MobilenetV2/Predictions/Reshape_1'
849849

850850
with tf.Session() as sess:
851851
# Add shapes to the graph.
852-
graph_def = nnvm.testing.tf.AddShapesToGraphDef(sess, out_node)
852+
graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
853853
tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
854854
tvm_output = run_tvm_graph(graph_def, data, 'input')
855855
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
@@ -861,9 +861,9 @@ def test_forward_resnetv2():
861861
'''test resnet model'''
862862
if is_gpu_available():
863863
with tf.Graph().as_default():
864-
graph_def = nnvm.testing.tf.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
864+
graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
865865
# Call the utility to import the graph definition into default graph.
866-
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
866+
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
867867

868868
data = np.random.uniform(size=(128, 224, 224, 3)).astype('float32')
869869
out_node = 'ArgMax'
@@ -879,7 +879,7 @@ def test_forward_resnetv2():
879879
dir(tf.contrib)
880880
def test_forward_ptb():
881881
'''test ptb model'''
882-
config = nnvm.testing.tf.get_config()
882+
config = tf_testing.get_config()
883883
num_steps = config.num_steps
884884
num_hidden = config.hidden_size
885885
num_layers = config.num_layers
@@ -936,7 +936,7 @@ def _get_sample(data, state):
936936
"float32")).asnumpy()
937937
state_output = model.get_output(1, tvm.nd.empty(out_state_shape,
938938
"float32")).asnumpy()
939-
sample = nnvm.testing.tf.pick_from_weight(tvm_output[0])
939+
sample = tf_testing.pick_from_weight(tvm_output[0])
940940

941941
return sample, state_output
942942

@@ -956,10 +956,10 @@ def _get_sample(data, state):
956956
return samples, state
957957

958958
with tf.Graph().as_default():
959-
word_to_id, id_to_word, graph_def = nnvm.testing.tf.get_workload_ptb()
959+
word_to_id, id_to_word, graph_def = tf_testing.get_workload_ptb()
960960
vocab_size = len(word_to_id)
961961
# Call the utility to import the graph definition into default graph.
962-
graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
962+
graph_def = tf_testing.ProcessGraphDefParam(graph_def)
963963
sess = tf.Session()
964964

965965
#TVM graph module creation
@@ -975,7 +975,7 @@ def _get_sample(data, state):
975975
for word in seed_for_sample],
976976
in_state, params, cnt_sample)
977977
tvm_sample_str = _pretty_print(tvm_samples, False, id_to_word)
978-
tf_samples, tf_state = nnvm.testing.tf.do_tf_sample(sess,
978+
tf_samples, tf_state = tf_testing.do_tf_sample(sess,
979979
[word_to_id[word] for word in seed_for_sample],
980980
in_state, cnt_sample)
981981
tf_sample_str = _pretty_print(tf_samples, False, id_to_word)

python/tvm/relay/frontend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
from .tflite import from_tflite
1414
from .coreml import from_coreml
1515
from .caffe2 import from_caffe2
16+
from .tensorflow import from_tensorflow

0 commit comments

Comments
 (0)