Skip to content

Commit f2ea7ee

Browse files
sergei-mironovSergey Mironov
authored andcommitted
[NNVM] Tensorflow, Always use first placeholder as input node, if available
Minor changes: * Report all missing operators in advance * Describe logic in comments
1 parent 1a3264e commit f2ea7ee

File tree

1 file changed

+28
-6
lines changed

1 file changed

+28
-6
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,7 @@ def from_tensorflow(self, graph):
902902
Follow the tensorflow graph definition to parse and convert it to NNVM.
903903
Some of the assumptions listed below.
904904
905-
-> First Const or Placeholder node will be considered as graph input.
905+
-> First Placeholder or Const node will be considered as graph input.
906906
-> Rest all Const nodes are params.
907907
-> Last node is assumed as graph output.
908908
-> _output_shapes : Attribute should present in the tenserflow forzen graph.
@@ -924,17 +924,39 @@ def from_tensorflow(self, graph):
924924
params : dict
925925
A dict of name: tvm.nd.array pairs, used as pretrained weights
926926
"""
927-
# Parse throught all nodes and start extracting
928-
# params aka Const nodes
929-
# input nodes : First const node
930-
# normal nodes : other normal nodes
931927

932928
try:
933929
from tensorflow.python.framework import tensor_util
934930
except ImportError as e:
935931
raise ImportError(
936932
"Unable to import tensorflow which is required {}".format(e))
937933

934+
# Parse the nodes briefly to determine the preconditions:
935+
# a. Placeholders
936+
# b. Missing operations to report in advance
937+
has_placeholders = False
938+
missing_operators = set()
939+
for node in graph.node:
940+
if node.op == "Placeholder":
941+
has_placeholders = True
942+
elif node.op == "Const":
943+
pass
944+
else:
945+
if node.op in _identity_list:
946+
pass
947+
elif node.op in _convert_map:
948+
pass
949+
else:
950+
missing_operators.add(node.op)
951+
952+
if missing_operators:
953+
raise NotImplementedError( \
954+
"The following operators are not implemented: {}".format(missing_operators))
955+
956+
# Parse the nodes to re-create TF graph using Symbol API of NNVM
957+
# * Input node will be produced by first placeholder node, or by first
958+
# Const node if there are no placeholders
959+
# * Other nodes will be mapped using `_convert_map` and other translation facilities
938960
for node in graph.node:
939961
# Tensorflow doesn't have seperate list for params extraction.
940962
# Operator name 'Const' is treated as a parameter to build NNVM params dict.
@@ -956,7 +978,7 @@ def from_tensorflow(self, graph):
956978
"Please freeze the graph with add_shapes=True")
957979
elif node.op == "Const":
958980
# Assuming first Const node as Graph Input node
959-
if self._input_node == '':
981+
if self._input_node == '' and not has_placeholders:
960982
self._input_node = node.name
961983
self._num_input += 1
962984
self._nodes[node.name] = _sym.Variable(name=node.name)

0 commit comments

Comments
 (0)