From ee924bed20cde197f3c8cf52eae7d759cb001d93 Mon Sep 17 00:00:00 2001 From: SRK Reddy Date: Mon, 13 May 2019 16:50:04 +0530 Subject: [PATCH] [TENSORLFOW] PlaceholderWithDefault (limited) implementation. --- python/tvm/relay/frontend/tensorflow.py | 6 +++--- .../frontend/tensorflow/test_forward.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 48f78837c525..571ab2f827e0 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1728,7 +1728,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): for node in graph.node: node_name_prefix = node.name.rsplit('/', 1)[0] control_flow_node_map[node_name_prefix].add(node.op) - if node.op == 'Placeholder': + if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault': # Give priority to user argument. if shape and node.name in shape: self._input_shapes[node.name] = list(shape[node.name]) @@ -1788,7 +1788,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): attr = self._parse_attr(node.attr) - elif node.op != "Placeholder": + elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault': # Pass the parsed shapes instead attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] @@ -1913,7 +1913,7 @@ def _parse_import_prerequisites(self, graph): """ missing_operators = set() for node in graph.node: - if node.op == "Placeholder": + if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault': pass elif node.op == "Const": pass diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 58bbdab02b84..fd96da57b431 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1527,6 +1527,24 @@ def test_forward_reduce_prod(): _test_forward_reduce_prod((5, 5), 0, True) _test_forward_reduce_prod((5, 5), 1, True) + +####################################################################### +# PlaceholderWithDefault +# ---------------------- +def test_placeholder(): + with tf.Graph().as_default(): + in_data1 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32) + var1 = tf.Variable(in_data1, name='in1') + var2 = array_ops.placeholder_with_default(var1, None, name='place1') + + in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32) + place1 = array_ops.placeholder(shape=in_data1.shape, dtype=in_data1.dtype, name='in2') + + out1 = tf.math.add(var1, var2, name='out1') + out2 = tf.math.add(out1, place1, name='out2') + + compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True) + ####################################################################### # Main # ---- @@ -1574,6 +1592,7 @@ def test_forward_reduce_prod(): test_forward_multi_input() test_forward_multi_output() test_forward_variable() + test_placeholder() # NN test_forward_convolution()