
Commit 92ec3f8

* wip
1 parent d097a66 commit 92ec3f8

2 files changed: 47 additions, 47 deletions

python/tvm/relay/frontend/tensorflow.py

Lines changed: 34 additions & 35 deletions
@@ -1,4 +1,4 @@
-# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
+# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
 """TF: Tensorflow frontend."""
 from __future__ import absolute_import as _abs
 from __future__ import print_function
@@ -7,17 +7,11 @@
 # Numpy support
 import numpy as np
 
+import tvm
 from tvm import relay
 from .. import ir_pass
 from .. import expr as _expr
 from .. import op as _op
-from ... import nd as _nd
-from .common import StrAttrsDict
-
-import tvm
-#from .. import graph as _graph
-#from .. compiler import graph_util, build_module
-#from .common import get_nnvm_op, AttrConverter as AttrConvert
 
 __all__ = ['from_tensorflow']
 
@@ -27,7 +21,7 @@ def _get_relay_op(op_name):
     except AttributeError:
         try:
             op = getattr(_op.nn, op_name)
-        except:
+        except AttributeError:
             op = getattr(_op.image, op_name)
 
     if not op:
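Catching AttributeError instead of a bare except: keeps only the intended "name not found" case rather than swallowing every exception. The same three-namespace fallback can also be read as a loop; a minimal standalone sketch, assuming only the standard relay op namespaces (the lookup name is illustrative, not part of this patch):

from tvm.relay import op as _op

def lookup(op_name):
    # Try each relay namespace in turn; getattr raises AttributeError
    # when the operator is not registered there.
    for namespace in (_op, _op.nn, _op.image):
        try:
            return getattr(namespace, op_name)
        except AttributeError:
            continue
    raise NotImplementedError("Operator {} not implemented.".format(op_name))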
@@ -161,15 +155,21 @@ def _required_attr(self, attr, key):
 
 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
-        pad = tvm.select((kernel1d - stride1d) > 0, (kernel1d - stride1d), relay.const(0))
+        pad = max(kernel1d - stride1d, 0)
     else:
-        pad = tvm.select((kernel1d - (input1d % stride1d)) > 0, (kernel1d - (input1d % stride1d)), relay.const(0))
+        pad = max(kernel1d - (input1d % stride1d), 0)
 
-    pad_before = pad // relay.const(2)
+    pad_before = pad // 2
     pad_after = pad - pad_before
 
     return [pad_before, pad_after]
 
+def _get_name_hint(node):
+    name = ''
+    if hasattr(node, "name_hint"):
+        name = node.name_hint
+    return name
+
 def _math_name_picker(surfix):
     def _impl(attr):
         return 'broadcast_' + surfix
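With the kernel, stride, and input extents arriving from TF attributes as plain Python ints, the symbolic tvm.select is unnecessary, so the helper now computes SAME padding eagerly. A standalone sketch of the arithmetic with illustrative numbers:

def get_pad_pair(input1d, kernel1d, stride1d):
    # Same logic as _get_pad_pair above, repeated for a quick numeric check.
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]

# A 224-wide input with a 7-wide kernel at stride 2 needs 5 pad pixels,
# split [2, 3], giving SAME's expected output width of ceil(224/2) = 112.
assert get_pad_pair(224, 7, 2) == [2, 3]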
@@ -318,7 +318,7 @@ def _impl(inputs, attr, params):
             attr['data_format'] = "NCHW"
             attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)]
             flip_layout = True
-        print("W Shape:", weights_shape)
+
         if attr['data_format'] == 'NHWC':
             kernel_h, kernel_w, _, depth_mult = weights_shape
             attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
@@ -532,7 +532,7 @@ def _impl(inputs, attr, params):
 
 def _squeeze():
     def _impl(inputs, attr, params):
-        if 0 == len(attr['squeeze_dims']):
+        if len(attr['squeeze_dims']) == 0:
             attr['squeeze_dims'] = None
         return AttrCvt(
             op_name="squeeze",
@@ -591,7 +591,7 @@ def _impl(inputs, attr, params):
 
 def _relu6():
     def _impl(inputs, attr, params):
-        return _op.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name'])
+        return _op.clip(inputs[0], a_min=0, a_max=6)
     return _impl
 
 def _shape():
@@ -647,11 +647,10 @@ def _impl(inputs, attr, params):
         new_input = []
         new_input.append(inputs.pop(0))
         new_input.append(inputs.pop(0))
-        return AttrCvt(
-            op_name="take",
-            extras={'axis':axis},
-            ignores=['Tindices', 'Tparams', 'validate_indices', \
-                     'Taxis', '_class'])(new_input, attr)
+        return AttrCvt(op_name="take",
+                       extras={'axis': tvm.const(axis)},
+                       ignores=['Tindices', 'Tparams', 'validate_indices', \
+                                'Taxis', '_class'])(new_input, attr)
     return _impl
 
 def _infer_out_shapes(inputs, params):
@@ -744,7 +743,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
         fshape_indices = None
         if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
             begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
-        out = _op.strided_slice(inputs[0], begin=begin, end=end, stride=stride)
+        out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
         out_shape = _infer_out_shapes(out, params)[0]
         if not fshape_indices:
             fshape_indices = range(len(out_shape))
@@ -758,7 +757,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
                 pass
             else:
                 final_output.append(out_shape[gather_index])
-        return _op.reshape(out, shape=tuple(final_output))
+        return _op.reshape(out, newshape=tuple(final_output))
     return _impl
 
 def _pad(name):
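Both hunks are keyword corrections: Relay's strided_slice takes strides, and its reshape takes newshape. A minimal sketch exercising the corrected keywords (shapes illustrative; assumes a TVM build with the relay API as above):

from tvm import relay

x = relay.var("x", shape=(4, 6), dtype="float32")
# First two rows, every second column -- the keyword is 'strides', not 'stride'.
sliced = relay.strided_slice(x, begin=[0, 0], end=[2, 6], strides=[1, 2])
# Flatten the resulting 2x3 block -- the keyword is 'newshape', not 'shape'.
flat = relay.reshape(sliced, newshape=(6,))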
@@ -785,9 +784,12 @@ def _transpose():
     def _impl(inputs, attr, params):
         # If perm is not specified, axes is left empty,
         # otherwise its value is get from params
-        param_name = inputs[1].name_hint
-        axes = params.get(param_name, tvm.nd.array([])).asnumpy()
-        return _op.transpose(inputs[0], axes=tuple(axes))
+        param_name = _get_name_hint(inputs[1])
+        if param_name in params:
+            axes = tuple(params.get(param_name).asnumpy())
+        else:
+            axes = None
+        return _op.transpose(inputs[0], axes=axes)
     return _impl
 
 def _rank():
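The guard matters because the perm input is not always a named parameter: only relay variables carry name_hint, so reading inputs[1].name_hint on a constant would raise AttributeError, and the old code passed an empty tuple when perm was absent, where axes=None is what requests relay's default axis reversal. A small sketch of the helper's behaviour (variable names illustrative):

from tvm import relay

def _get_name_hint(node):
    # Mirrors the helper added earlier: empty string for unnamed nodes.
    return node.name_hint if hasattr(node, "name_hint") else ''

perm_var = relay.var("transpose_perm", shape=(3,), dtype="int32")
perm_const = relay.const([0, 2, 1], dtype="int32")

assert _get_name_hint(perm_var) == "transpose_perm"
assert _get_name_hint(perm_const) == ''   # not in params, so axes stays None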
@@ -799,7 +801,7 @@ def _impl(inputs, attr, params):
         params[name] = tvm.nd.array([len(input_shapes[0])])
         return [_expr.var(name,
                           shape=params[name].shape,
-                          dtype=params[name].dtype)]
+                          dtype='int32')]
 
     return _impl
 
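Pinning dtype='int32' decouples the declared relay var from whatever integer width tvm.nd.array happened to store: numpy's default integer, and hence the stored param's dtype, is platform-dependent, typically int64 on 64-bit Linux. A quick illustration of that mismatch (output is platform-dependent):

import numpy as np
import tvm

# Store the rank the same way the frontend does above.
rank = tvm.nd.array(np.array([4]))
print(rank.dtype)   # usually 'int64' on 64-bit Linux, not 'int32'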
@@ -813,20 +815,22 @@ def _impl(inputs, attr, params):
         params[name] = tvm.nd.array([start, limit, delta])
         return [_expr.var(name,
                           shape=params[name].shape,
-                          dtype=params[name].dtype)]
+                          dtype='int32')]
     return _impl
 
 def _elu():
     def _impl(inputs, attr, params):
         alpha = relay.const(-1.0, attr['T'].name)
-        return alpha * _op.nn.relu(relay.const(1, attr['T'].name) - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
+        return alpha * _op.nn.relu(relay.const(1, attr['T'].name) \
+                                   - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
     return _impl
 
 def _selu():
     def _impl(inputs, attr, params):
         alpha = relay.const(-1.6732632423543772848170429916717)
         gamma = relay.const(1.0507009873554804934193349852946)
-        return gamma * (alpha * _op.nn.relu(relay.const(1, attr['T'].name) - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
+        return gamma * (alpha * _op.nn.relu(relay.const(1, attr['T'].name) \
+                                            - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
     return _impl
 
 def _mean():
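The two re-wraps are cosmetic, but the relu/exp algebra is worth a sanity check: with alpha = -1, the expression -relu(1 - exp(x)) + relu(x) equals exp(x) - 1 for x < 0 and x for x >= 0, which is exactly ELU, and _selu is the same expression scaled by SELU's gamma with its published alpha. A numpy sketch of the identity (illustrative, not part of the patch):

import numpy as np

def elu_via_relu(x):
    # Same algebra as _elu above, written in numpy.
    relu = lambda v: np.maximum(v, 0.0)
    return -1.0 * relu(1.0 - np.exp(x)) + relu(x)

x = np.array([-2.0, -0.5, 0.0, 1.5])
reference = np.where(x < 0, np.exp(x) - 1.0, x)
assert np.allclose(elu_via_relu(x), reference)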
@@ -873,7 +877,7 @@ def _impl(inputs, attr, params):
     'MatMul'                    : _matmul(),
     'MaxPool'                   : _pooling('max_pool'),
     'Add'                       : _elemwise('add'),
-    'Sub'                       : _elemwise('sub'),
+    'Sub'                       : _elemwise('subtract'),
     'Mul'                       : _elemwise('multiply'),
     'Maximum'                   : _elemwise('max'),
     'Minimum'                   : _elemwise('min'),
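The string matters because _elemwise resolves it to a relay op (presumably through the _get_relay_op lookup above), and Relay registers the broadcasting subtraction as subtract; there is no relay sub. A quick sketch (variable names illustrative):

from tvm import relay

a = relay.var("a", shape=(2,), dtype="float32")
b = relay.var("b", shape=(2,), dtype="float32")
diff = relay.subtract(a, b)   # getattr(relay.op, 'sub') would raise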
@@ -971,10 +975,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))
 
-        final_op = None
         # Parse the nodes to re-create TF graph using Symbol API of NNVM
         for node in graph.node:
-            print("Node: ", node.name, "Node Op:", node.op)
             # Tensorflow doesn't have seperate list for params extraction.
             # Operator name 'Const' is treated as a parameter to build NNVM params dict.
 
@@ -1070,9 +1072,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         out = op
         out = out[0] if len(out) == 1 else _expr.Tuple(out)
         func = _expr.Function(ir_pass.free_vars(out), out)
-        print("OP:", op)
-        print("Func:", func)
-        print("Shape:", relay.ir_pass.infer_type(op[0]).checked_type)
 
         return func, self._params
 
tests/python/frontend/tensorflow/test_forward.py

Lines changed: 13 additions & 12 deletions
@@ -939,19 +939,20 @@ def test_forward_l2_normalize():
 # transpose
 # ---------
 def _test_forward_transpose(ishape, axes=None):
-    input = np.random.uniform(size=ishape).astype(np.float32)
+    data = np.random.uniform(size=ishape).astype(np.float32)
 
     with tf.Graph().as_default():
-        in1 = tf.placeholder(shape=input.shape, dtype=input.dtype, name="transpose_data")
+        in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name="transpose_data")
 
         if axes is None:
             tf.transpose(in1)
         else:
             tf.transpose(in1, perm=axes)
 
-        compare_tf_with_tvm(input, 'transpose_data:0', 'transpose:0')
+        compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0')
 
 def test_forward_transpose():
+    _test_forward_transpose((2, 3, 4), (1, 2, 0))
     _test_forward_transpose((2, 3, 4))
     _test_forward_transpose((7, 8, 8, 10))
     _test_forward_transpose((2, 3, 4), (1, 2, 0))
@@ -1056,16 +1057,8 @@ def test_forward_rel_ops():
 # Main
 # ----
 if __name__ == '__main__':
-    # NN
-    test_forward_convolution()
-    #test_forward_pooling()
-    #if tf.__version__ == '1.4.1':
-    #    _test_forward_concat_v2()
-    #test_forward_lrn()
-    #test_forward_l2_normalize()
-    exit(0)
     # Transforms
-    test_forward_transpose()
+    #test_forward_transpose()
     test_forward_reshape()
     test_forward_squeeze()
     test_forward_pack()
@@ -1108,3 +1101,11 @@ def test_forward_rel_ops():
 
     # Relational ops
     test_forward_rel_ops()
+
+    # NN
+    #test_forward_convolution()
+    #test_forward_pooling()
+    #if tf.__version__ == '1.4.1':
+    #    _test_forward_concat_v2()
+    #test_forward_lrn()
+    #test_forward_l2_normalize()
