-# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
+# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
 """TF: Tensorflow frontend."""
 from __future__ import absolute_import as _abs
 from __future__ import print_function

 # Numpy support
 import numpy as np

+import tvm
 from tvm import relay
 from .. import ir_pass
 from .. import expr as _expr
 from .. import op as _op
-from ... import nd as _nd
-from .common import StrAttrsDict
-
-import tvm
-#from .. import graph as _graph
-#from .. compiler import graph_util, build_module
-#from .common import get_nnvm_op, AttrConverter as AttrConvert

 __all__ = ['from_tensorflow']

@@ -27,7 +21,7 @@ def _get_relay_op(op_name):
     except AttributeError:
         try:
             op = getattr(_op.nn, op_name)
-        except:
+        except AttributeError:
             op = getattr(_op.image, op_name)

     if not op:
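Narrowing the bare `except:` matters: a bare clause would also swallow exceptions such as `KeyboardInterrupt` and hides genuine bugs raised during attribute resolution. A minimal sketch of the same fallback-lookup pattern, using stand-in namespaces rather than Relay's real op modules:

from types import SimpleNamespace

# Stand-ins for _op, _op.nn and _op.image (illustrative only).
_ops = SimpleNamespace(add='add-op')
_nn = SimpleNamespace(conv2d='conv2d-op')
_image = SimpleNamespace(resize='resize-op')

def get_op(op_name):
    for namespace in (_ops, _nn, _image):
        try:
            return getattr(namespace, op_name)
        except AttributeError:   # catch only the failed lookup, nothing else
            continue
    raise NotImplementedError(op_name)

print(get_op('conv2d'))   # -> conv2d-op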
@@ -161,15 +155,21 @@ def _required_attr(self, attr, key):

 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
-        pad = tvm.select((kernel1d - stride1d) > 0, (kernel1d - stride1d), relay.const(0))
+        pad = max(kernel1d - stride1d, 0)
     else:
-        pad = tvm.select((kernel1d - (input1d % stride1d)) > 0, (kernel1d - (input1d % stride1d)), relay.const(0))
+        pad = max(kernel1d - (input1d % stride1d), 0)

-    pad_before = pad // relay.const(2)
+    pad_before = pad // 2
     pad_after = pad - pad_before

     return [pad_before, pad_after]

+def _get_name_hint(node):
+    name = ''
+    if hasattr(node, "name_hint"):
+        name = node.name_hint
+    return name
+
 def _math_name_picker(surfix):
     def _impl(attr):
         return 'broadcast_' + surfix
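The rewritten helper now computes TensorFlow-style SAME padding on plain Python integers at graph-build time instead of constructing `tvm.select` expressions. A quick standalone sanity check of the same arithmetic:

def _get_pad_pair(input1d, kernel1d, stride1d):
    # Same logic as the patched helper above.
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]

# 224-wide input, 7-wide kernel, stride 2: pad = 7 - 2 = 5, split [2, 3].
assert _get_pad_pair(224, 7, 2) == [2, 3]
# Non-divisible case: 10 % 3 = 1, so pad = 3 - 1 = 2, split [1, 1].
assert _get_pad_pair(10, 3, 3) == [1, 1]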
@@ -318,7 +318,7 @@ def _impl(inputs, attr, params):
             attr['data_format'] = "NCHW"
             attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)]
             flip_layout = True
-        print("W Shape:", weights_shape)
+
         if attr['data_format'] == 'NHWC':
             kernel_h, kernel_w, _, depth_mult = weights_shape
             attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
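For reference, the hunk above flips per-dimension attributes from NHWC to NCHW by re-indexing with `(0, 3, 1, 2)`; a tiny illustration with made-up stride values:

# NHWC attribute order is (N, H, W, C); indexing with (0, 3, 1, 2)
# rearranges the same values into NCHW order.
strides_nhwc = [1, 2, 2, 1]
strides_nchw = [strides_nhwc[ii] for ii in (0, 3, 1, 2)]
print(strides_nchw)   # [1, 1, 2, 2]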
@@ -532,7 +532,7 @@ def _impl(inputs, attr, params):

 def _squeeze():
     def _impl(inputs, attr, params):
-        if 0 == len(attr['squeeze_dims']):
+        if len(attr['squeeze_dims']) == 0:
             attr['squeeze_dims'] = None
         return AttrCvt(
             op_name="squeeze",
@@ -591,7 +591,7 @@ def _impl(inputs, attr, params):

 def _relu6():
     def _impl(inputs, attr, params):
-        return _op.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name'])
+        return _op.clip(inputs[0], a_min=0, a_max=6)
     return _impl

 def _shape():
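ReLU6 is just a saturating clamp to [0, 6], which is why dropping the unsupported `name` keyword loses nothing. Numpy equivalent for reference:

import numpy as np

x = np.array([-2.0, 3.0, 8.0])
# relu6(x) == clip(x, 0, 6)
print(np.clip(x, 0.0, 6.0))   # [0. 3. 6.]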
@@ -647,11 +647,10 @@ def _impl(inputs, attr, params):
         new_input = []
         new_input.append(inputs.pop(0))
         new_input.append(inputs.pop(0))
-        return AttrCvt(
-            op_name="take",
-            extras={'axis': axis},
-            ignores=['Tindices', 'Tparams', 'validate_indices', \
-                     'Taxis', '_class'])(new_input, attr)
+        return AttrCvt(op_name="take",
+                       extras={'axis': tvm.const(axis)},
+                       ignores=['Tindices', 'Tparams', 'validate_indices', \
+                                'Taxis', '_class'])(new_input, attr)
     return _impl

 def _infer_out_shapes(inputs, params):
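`GatherV2`'s axis now travels as a constant in `extras`; the op itself behaves like a `take` along one axis. A pure-numpy sketch of the semantics (numpy stand-in, not the TVM API):

import numpy as np

params = np.array([[1, 2], [3, 4], [5, 6]])
indices = np.array([2, 0])

# GatherV2 with axis=0 picks whole rows, exactly like np.take on axis 0.
print(np.take(params, indices, axis=0))   # [[5 6] [1 2]]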
@@ -744,7 +743,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
         fshape_indices = None
         if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
             begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
-        out = _op.strided_slice(inputs[0], begin=begin, end=end, stride=stride)
+        out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
         out_shape = _infer_out_shapes(out, params)[0]
         if not fshape_indices:
             fshape_indices = range(len(out_shape))
@@ -758,7 +757,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
                 pass
             else:
                 final_output.append(out_shape[gather_index])
-        return _op.reshape(out, shape=tuple(final_output))
+        return _op.reshape(out, newshape=tuple(final_output))
     return _impl

 def _pad(name):
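Both renames in the two hunks above track Relay's keyword names: `strided_slice` takes `strides` and `reshape` takes `newshape`. A numpy sketch of what the corrected calls compute (numpy stand-ins, not the Relay API):

import numpy as np

x = np.arange(12).reshape(3, 4)
# strided_slice(begin=[0, 1], end=[3, 4], strides=[2, 2]) on a (3, 4) input:
sliced = x[0:3:2, 1:4:2]
print(sliced.shape)        # (2, 2)
# reshape(newshape=(4,)) then flattens the slice:
print(sliced.reshape(4))   # [ 1  3  9 11]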
@@ -785,9 +784,12 @@ def _transpose():
     def _impl(inputs, attr, params):
         # If perm is not specified, axes is left empty,
         # otherwise its value is taken from params
-        param_name = inputs[1].name_hint
-        axes = params.get(param_name, tvm.nd.array([])).asnumpy()
-        return _op.transpose(inputs[0], axes=tuple(axes))
+        param_name = _get_name_hint(inputs[1])
+        if param_name in params:
+            axes = tuple(params.get(param_name).asnumpy())
+        else:
+            axes = None
+        return _op.transpose(inputs[0], axes=axes)
     return _impl

 def _rank():
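With the new guard, a `Transpose` whose `perm` input is not a constant in `params` falls back to `axes=None`, which for both Relay and numpy means "reverse all dimensions" — the same as TensorFlow's default perm. A numpy illustration:

import numpy as np

x = np.zeros((2, 3, 4))
# axes=None reverses dimension order, matching tf.transpose's default perm.
print(np.transpose(x).shape)             # (4, 3, 2)
# An explicit perm recovered from params behaves like np.transpose(x, perm).
print(np.transpose(x, (0, 2, 1)).shape)  # (2, 4, 3)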
@@ -799,7 +801,7 @@ def _impl(inputs, attr, params):
         params[name] = tvm.nd.array([len(input_shapes[0])])
         return [_expr.var(name,
                           shape=params[name].shape,
-                          dtype=params[name].dtype)]
+                          dtype='int32')]

     return _impl

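Pinning `dtype='int32'` here (and in the `Range` hunk below) avoids inheriting whatever dtype the backing ndarray happened to get; on most 64-bit platforms Python ints become int64 arrays unless the dtype is stated. A small hedged illustration:

import numpy as np

print(np.array([4]).dtype)                  # typically int64 on 64-bit hosts
print(np.array([4], dtype='int32').dtype)   # int32, as the frontend now pins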
@@ -813,20 +815,22 @@ def _impl(inputs, attr, params):
         params[name] = tvm.nd.array([start, limit, delta])
         return [_expr.var(name,
                           shape=params[name].shape,
-                          dtype=params[name].dtype)]
+                          dtype='int32')]
     return _impl

 def _elu():
     def _impl(inputs, attr, params):
         alpha = relay.const(-1.0, attr['T'].name)
-        return alpha * _op.nn.relu(relay.const(1, attr['T'].name) - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
+        return alpha * _op.nn.relu(relay.const(1, attr['T'].name) \
+                                   - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
     return _impl

 def _selu():
     def _impl(inputs, attr, params):
         alpha = relay.const(-1.6732632423543772848170429916717)
         gamma = relay.const(1.0507009873554804934193349852946)
-        return gamma * (alpha * _op.nn.relu(relay.const(1, attr['T'].name) - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
+        return gamma * (alpha * _op.nn.relu(relay.const(1, attr['T'].name) \
+                                            - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
     return _impl

 def _mean():
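The new line wraps don't change the math: ELU is still composed from two ReLUs, elu(x) = -relu(1 - exp(x)) + relu(x), which equals exp(x) - 1 for x < 0 and x otherwise; SELU scales the same expression by gamma with its own alpha. A numpy check of the identity:

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

x = np.linspace(-3, 3, 7)
composed = -1.0 * relu(1.0 - np.exp(x)) + relu(x)
reference = np.where(x < 0, np.exp(x) - 1.0, x)
assert np.allclose(composed, reference)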
@@ -873,7 +877,7 @@ def _impl(inputs, attr, params):
     'MatMul': _matmul(),
     'MaxPool': _pooling('max_pool'),
     'Add': _elemwise('add'),
-    'Sub': _elemwise('sub'),
+    'Sub': _elemwise('subtract'),
     'Mul': _elemwise('multiply'),
     'Maximum': _elemwise('max'),
     'Minimum': _elemwise('min'),
@@ -971,10 +975,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))

-        final_op = None
         # Parse the nodes to re-create TF graph using Symbol API of NNVM
         for node in graph.node:
-            print("Node: ", node.name, "Node Op:", node.op)
             # Tensorflow doesn't have a separate list for params extraction.
             # Operator name 'Const' is treated as a parameter to build the NNVM params dict.
@@ -1070,9 +1072,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             out = op
         out = out[0] if len(out) == 1 else _expr.Tuple(out)
         func = _expr.Function(ir_pass.free_vars(out), out)
-        print("OP:", op)
-        print("Func:", func)
-        print("Shape:", relay.ir_pass.infer_type(op[0]).checked_type)

         return func, self._params