@@ -168,81 +168,7 @@ def _impl(inputs, attr, params):
168168 custom_check = _dimension_constraint ())(inputs , attr )
169169 return _impl
170170
171- def _conv ():
172- def _impl (inputs , attr , params ):
173- attr ['data_format' ] = attr ['data_format' ].decode ("utf-8" )
174-
175- # Extract kernel shape from params
176- conv_param_weights = params [inputs [1 ].list_output_names ()[0 ]]
177-
178- if attr ['data_format' ] == 'NHWC' :
179- attr ['kernel_shape' ] = (conv_param_weights .shape [0 ], conv_param_weights .shape [1 ])
180- attr ['channels' ] = conv_param_weights .shape [3 ]
181- if 'dilations' in attr :
182- attr ['dilations' ] = (attr ['dilations' ][0 ], attr ['dilations' ][1 ])
183- elif attr ['data_format' ] == 'NCHW' :
184- attr ['kernel_shape' ] = (conv_param_weights .shape [2 ], conv_param_weights .shape [3 ])
185- attr ['channels' ] = conv_param_weights .shape [1 ]
186- if 'dilations' in attr :
187- attr ['dilations' ] = (attr ['dilations' ][2 ], attr ['dilations' ][3 ])
188- else :
189- raise TypeError ("Unsupported data format type : {}" .format (attr ['data_format' ]))
190-
191- # Fix strides
192- attr ['strides' ] = (attr ['strides' ][1 ], attr ['strides' ][2 ])
193-
194- # Fix padding
195- input_shapes = attr ['_input_shapes' ][inputs [0 ]]
196- attr ['padding' ] = attr ['padding' ].decode ("utf-8" )
197-
198- if attr ['padding' ] == 'VALID' :
199- attr ['padding' ] = [0 , 0 ]
200- elif attr ['padding' ] == 'SAME' :
201- stride_h , stride_w = attr ['strides' ]
202- kernel_h , kernel_w = attr ['kernel_shape' ]
203- if attr ['data_format' ] == 'NHWC' :
204- in_h = input_shapes [0 ][1 ]
205- in_w = input_shapes [0 ][2 ]
206- else :
207- in_h = input_shapes [0 ][2 ]
208- in_w = input_shapes [0 ][3 ]
209-
210- pad_v = _get_pad_pair (in_h , kernel_h , stride_h )
211- pad_h = _get_pad_pair (in_w , kernel_w , stride_w )
212-
213- if attr ['data_format' ] == 'NHWC' :
214- inputs [0 ] = _sym .pad (data = inputs [0 ],
215- pad_width = ((0 , 0 ),
216- (pad_v [0 ], pad_v [1 ]),
217- (pad_h [0 ], pad_h [1 ]),
218- (0 , 0 )))
219- else :
220- inputs [0 ] = _sym .pad (data = inputs [0 ],
221- pad_width = ((0 , 0 ),
222- (0 , 0 ),
223- (pad_v [0 ], pad_v [1 ]),
224- (pad_h [0 ], pad_h [1 ])))
225-
226- attr ['padding' ] = [0 , 0 ]
227-
228- else :
229- raise TypeError ("Unsupported padding type : {}" .format (attr ['padding' ]))
230-
231- if 'kernel_layout' not in attr :
232- attr ['kernel_layout' ] = 'HWIO' if attr ['data_format' ] == 'NHWC' else 'OIHW'
233-
234- return AttrCvt (
235- op_name = _dimension_picker ('conv' ),
236- transforms = {
237- 'kernel_shape' : 'kernel_size' ,
238- 'data_format' : 'layout' ,
239- 'dilations' : ('dilation' , (0 , 0 )),
240- 'group' : ('groups' , 1 )},
241- extras = {'use_bias' : len (inputs ) == 3 },
242- custom_check = _dimension_constraint ())(inputs , attr )
243- return _impl
244-
245- def _depthwise_conv ():
171+ def _conv (opname ):
246172 def _impl (inputs , attr , params ):
247173 attr ['data_format' ] = attr ['data_format' ].decode ("utf-8" )
248174 input_shapes = attr ['_input_shapes' ][inputs [0 ]]
@@ -253,24 +179,33 @@ def _impl(inputs, attr, params):
253179 if attr ['data_format' ] == 'NHWC' :
254180 kernel_h , kernel_w , _ , depth_mult = conv_param_weights .shape
255181 attr ['kernel_shape' ] = (conv_param_weights .shape [0 ], conv_param_weights .shape [1 ])
256- attr ['channels' ] = input_shapes [0 ][3 ] * depth_mult
182+ if opname == 'conv' :
183+ attr ['channels' ] = conv_param_weights .shape [3 ]
184+ else :
185+ attr ['channels' ] = input_shapes [0 ][3 ] * depth_mult
186+
257187 if 'dilations' in attr :
258188 attr ['dilations' ] = (attr ['dilations' ][0 ], attr ['dilations' ][1 ])
259189 elif attr ['data_format' ] == 'NCHW' :
260190 depth_mult , _ , kernel_h , kernel_w = conv_param_weights .shape
261191 attr ['kernel_shape' ] = (conv_param_weights .shape [2 ], conv_param_weights .shape [3 ])
262- attr ['channels' ] = input_shapes [0 ][1 ] * depth_mult
192+ if opname == 'conv' :
193+ attr ['channels' ] = conv_param_weights .shape [1 ]
194+ else :
195+ attr ['channels' ] = input_shapes [0 ][1 ] * depth_mult
196+
263197 if 'dilations' in attr :
264198 attr ['dilations' ] = (attr ['dilations' ][2 ], attr ['dilations' ][3 ])
265199 else :
266200 raise TypeError ("Unsupported data format type : {}" .format (attr ['data_format' ]))
267201
202+
203+ if opname == 'depthwise' :
204+ attr ['groups' ] = attr ['channels' ]
205+
268206 # Fix strides
269207 attr ['strides' ] = (attr ['strides' ][1 ], attr ['strides' ][2 ])
270208
271- # Fix groups
272- attr ['groups' ] = attr ['channels' ]
273-
274209 # Fix padding
275210 attr ['padding' ] = attr ['padding' ].decode ("utf-8" )
276211
@@ -308,7 +243,10 @@ def _impl(inputs, attr, params):
308243 raise TypeError ("Unsupported padding type : {}" .format (attr ['padding' ]))
309244
310245 if 'kernel_layout' not in attr :
311- attr ['kernel_layout' ] = 'HWOI' if attr ['data_format' ] == 'NHWC' else 'OIHW'
246+ if opname == 'conv' :
247+ attr ['kernel_layout' ] = 'HWIO' if attr ['data_format' ] == 'NHWC' else 'OIHW'
248+ else :
249+ attr ['kernel_layout' ] = 'HWOI' if attr ['data_format' ] == 'NHWC' else 'OIHW'
312250
313251 return AttrCvt (
314252 op_name = _dimension_picker ('conv' ),
@@ -687,7 +625,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
687625 'CheckNumerics' : _check_numerics (),
688626 'Concat' : _concat (),
689627 'ConcatV2' : _concatV2 (),
690- 'Conv2D' : _conv (),
628+ 'Conv2D' : _conv ('conv' ),
691629 'DecodeJpeg' : _decode_image (),
692630 'ExpandDims' : _expand_dims (),
693631 'Identity' : _identity (),
@@ -704,7 +642,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
704642 'Squeeze' : _squeeze (),
705643 'FusedBatchNorm' : _fused_batch_norm (),
706644 'Relu6' : _relu6 (),
707- 'DepthwiseConv2dNative' : _depthwise_conv ( ),
645+ 'DepthwiseConv2dNative' : _conv ( 'depthwise' ),
708646 'Shape' : _shape (),
709647 'Sigmoid' : AttrCvt ('sigmoid' ),
710648 'Fill' : _fill (),
@@ -895,28 +833,6 @@ def _get_abs_layer_name(node):
895833 params , num_layers )
896834 return sym
897835
898-
899- def _parse_import_prerequisites (graph ):
900- """ Calculate the named preconditions from TensorFlow `graph`.
901- Return prerequisites for parsing:
902- a. Set of operator names which don't have their mapping in TVM, i.e.
903- which are not supported
904- """
905- missing_operators = set ()
906- for node in graph .node :
907- if node .op == "Placeholder" :
908- pass
909- elif node .op == "Const" :
910- pass
911- else :
912- if any ([node .op in t for t in [_identity_list , _convert_map , _convert_map_rnn ]]):
913- pass
914- else :
915- missing_operators .add (node .op )
916-
917- return missing_operators
918-
919-
920836class GraphProto (object ):
921837 """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
922838 Definition:
@@ -925,12 +841,8 @@ class GraphProto(object):
925841 def __init__ (self ):
926842 self ._nodes = {}
927843 self ._params = {}
928- self ._renames = {}
929- self ._replacements = {}
930844 self ._output_shapes = {}
931- self ._num_input = 0
932845 self ._num_param = 0
933- self ._input_node = ''
934846 self ._num_rnn_layer = False
935847
936848 def from_tensorflow (self , graph ):
@@ -969,7 +881,7 @@ def from_tensorflow(self, graph):
969881 raise ImportError (
970882 "Unable to import tensorflow which is required {}" .format (e ))
971883
972- missing_operators = _parse_import_prerequisites (graph )
884+ missing_operators = self . _parse_import_prerequisites (graph )
973885
974886 if missing_operators :
975887 raise NotImplementedError ( \
@@ -979,58 +891,42 @@ def from_tensorflow(self, graph):
979891 for node in graph .node :
980892 # Tensorflow doesn't have seperate list for params extraction.
981893 # Operator name 'Const' is treated as a parameter to build NNVM params dict.
894+
982895 input_shapes = {}
896+
897+ attr = self ._parse_attr (node .attr )
898+
899+ #Variable converted to Const will not have only value attr
900+ if 'value' in attr and node .op == 'Const' :
901+ tensor_value = attr ['value' ]
902+ self ._output_shapes [node .name ] = \
903+ [tensor_util .TensorShapeProtoToList ( \
904+ tensor_value .tensor_shape )]
905+ elif '_output_shapes' in attr :
906+ self ._output_shapes [node .name ] = \
907+ [tensor_util .TensorShapeProtoToList (shape ) \
908+ for shape in attr ['_output_shapes' ]]
909+ else :
910+ raise NotImplementedError ( \
911+ "Please freeze the graph with add_shapes=True" )
912+
983913 if node .op == "Placeholder" :
984- self ._input_node = node .name
985- self . _num_input += 1
914+ self ._nodes [ node . name ] = _sym . Variable ( name = node .name ,
915+ shape = self . _output_shapes [ node . name ][ 0 ])
986916
987- try :
988- self ._output_shapes [node .name ] = \
989- [tensor_util .TensorShapeProtoToList (shape ) \
990- for shape in self ._parse_attr (node .attr )['_output_shapes' ]]
991- self ._nodes [node .name ] = _sym .Variable (name = node .name ,
992- shape = self ._output_shapes [node .name ][0 ])
993- input_shapes [self ._nodes [node .name ]] = self ._output_shapes [node .name ]
994- except KeyError :
995- raise NotImplementedError ( \
996- "Please freeze the graph with add_shapes=True" )
917+ #input_shapes[self._nodes[node.name]] = self._output_shapes[node.name]
997918 elif node .op == "Const" :
998- if self ._input_node == '' :
999- self ._input_node = node .name
1000- self ._num_input += 1
1001- self ._nodes [node .name ] = _sym .Variable (name = node .name )
1002- else :
1003- # Rest all nodes are Param nodes, lets parse
1004- self ._num_param += 1
1005- for key , value in node .attr .items ():
1006- self ._parse_param (key , value , node .name )
1007- if node .name not in self ._nodes :
1008- raise NotImplementedError ( \
1009- "Const {} couldn't be converted to Param." .format (node .name ))
1010- attr = self ._parse_attr (node .attr )
1011- #Variable converted to Const will not have only value attr
1012- if 'value' in attr :
1013- tensor_value = attr ['value' ]
1014- self ._output_shapes [node .name ] = \
1015- [tensor_util .TensorShapeProtoToList ( \
1016- tensor_value .tensor_shape )]
1017- elif '_output_shapes' in attr :
1018- self ._output_shapes [node .name ] = \
1019- [tensor_util .TensorShapeProtoToList (shape ) \
1020- for shape in self ._parse_attr (node .attr )['_output_shapes' ]]
1021- else :
919+ # All Const nodes are Param nodes, lets parse
920+ self ._num_param += 1
921+ for key , value in node .attr .items ():
922+ self ._parse_param (key , value , node .name )
923+ if node .name not in self ._nodes :
1022924 raise NotImplementedError ( \
1023- "Please freeze the graph with add_shapes=True" )
1024- else :
925+ "Const {} couldn't be converted to Param." . format ( node . name ) )
926+
1025927 attr = self ._parse_attr (node .attr )
1026- try :
1027- self ._output_shapes [node .name ] = \
1028- [tensor_util .TensorShapeProtoToList (shape ) \
1029- for shape in attr ['_output_shapes' ]]
1030- except KeyError :
1031- raise NotImplementedError ( \
1032- "Please freeze the graph with add_shapes=True" )
1033928
929+ else :
1034930 # Pass the parsed shapes instead
1035931 attr ["_output_shapes" ] = self ._output_shapes [node .name ]
1036932
@@ -1045,11 +941,12 @@ def from_tensorflow(self, graph):
1045941 if ":" in node .input [0 ]:
1046942 in_name , _ = node .input [0 ].split (':' )
1047943 node .input [0 ] = in_name
944+
945+ # Fill shapes for all inputs in a list
1048946 try :
1049947 inputs = [self ._nodes [i ] for i in node .input ]
1050948 for i in node .input :
1051- if i not in self ._params :
1052- input_shapes [self ._nodes [i ]] = self ._output_shapes [i ]
949+ input_shapes [self ._nodes [i ]] = self ._output_shapes [i ]
1053950 attr ['_input_shapes' ] = input_shapes
1054951 except KeyError :
1055952 # TODO: Need to find clean way to handle '^CheckNumerics'
@@ -1061,18 +958,40 @@ def from_tensorflow(self, graph):
1061958 # Assuming only one output.
1062959 self ._nodes [node .name ] = op
1063960 node_output = op
961+
1064962 # Assume the final node is the output node
1065963 out = node_output
1066964
1067965 #Add the RNN outputs also with 'head' nodes of the nnvm graph
1068966 if self ._num_rnn_layer :
1069967 out_rnn = _sym .concatenate (* self ._out_rnn , axis = 0 )
1070968 out = [out , out_rnn ]
969+
1071970 if isinstance (out , list ):
1072971 out = _sym .Group (out )
1073972
1074973 return out , self ._params
1075974
975+ def _parse_import_prerequisites (self , graph ):
976+ """ Calculate the named preconditions from TensorFlow `graph`.
977+ Return prerequisites for parsing:
978+ a. Set of operator names which don't have their mapping in TVM, i.e.
979+ which are not supported
980+ """
981+ missing_operators = set ()
982+ for node in graph .node :
983+ if node .op == "Placeholder" :
984+ pass
985+ elif node .op == "Const" :
986+ pass
987+ else :
988+ if any ([node .op in t for t in [_identity_list , _convert_map , _convert_map_rnn ]]):
989+ pass
990+ else :
991+ missing_operators .add (node .op )
992+
993+ return missing_operators
994+
1076995 def _parse_param (self , key , value , name ):
1077996 try :
1078997 from tensorflow .python .framework import tensor_util
@@ -1082,6 +1001,13 @@ def _parse_param(self, key, value, name):
10821001
10831002 if key == 'value' :
10841003 np_array = tensor_util .MakeNdarray (value .tensor )
1004+
1005+ if np_array .dtype == np .dtype (object ):
1006+ # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
1007+ # Just leave it as placeholder.
1008+ self ._nodes [name ] = _sym .Variable (name = name )
1009+ return
1010+
10851011 array_ndim = len (np_array .shape )
10861012 if array_ndim == 0 :
10871013 new_array = np .empty ([1 ], dtype = np_array .dtype )
0 commit comments