Skip to content

Commit b7e7e8a

Browse files
committed
l2normalization operator support for tensorflow
1 parent 54ca149 commit b7e7e8a

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,21 @@ def _impl(inputs, attr, params):
482482
return AttrCvt(op_name='lrn')(new_inputs, attr_new)
483483
return _impl
484484

485+
def _sum():
486+
def _impl(inputs, attr, params):
487+
axis = params.pop(inputs[1].list_output_names()[0]).asnumpy()
488+
return AttrCvt(
489+
op_name='sum',
490+
extras={'axis': axis},
491+
transforms={'keep_dims':'keepdims'},
492+
ignores=['name', 'Tidx'])(inputs[0], attr)
493+
return _impl
494+
495+
def _square():
496+
def _impl(inputs, attr, params):
497+
return _sym.elemwise_mul(inputs[0], inputs[0])
498+
return _impl
499+
485500
def _gather_v2():
486501
"Tensorflow now support only gatherv2"
487502
def _impl(inputs, attr, params):
@@ -677,13 +692,17 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
677692
'Identity' : _identity(),
678693
'MatMul' : _matmul(),
679694
'MaxPool' : _pooling('max_pool'),
695+
'Add' : _elemwise('add'),
696+
'Sub' : _elemwise('sub'),
680697
'Mul' : _elemwise('mul'),
698+
'Maximum' : _elemwise('max'),
699+
'Minimum' : _elemwise('min'),
700+
'Sum' : _sum(),
701+
'Square' : _square(),
681702
'Relu' : AttrCvt('relu'),
682703
'Reshape' : _reshape(),
683704
'ResizeBilinear' : _resize_bilinear(),
684705
'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
685-
'Sub' : _elemwise('sub'),
686-
'Add' : _elemwise('add'),
687706
'Rsqrt' : _rsqrt(),
688707
'Squeeze' : _squeeze(),
689708
'FusedBatchNorm' : _fused_batch_norm(),

nnvm/tests/python/frontend/tensorflow/test_forward.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from tensorflow.python.framework import constant_op
1313
from tensorflow.python.framework import graph_util
1414
from tensorflow.python.ops import nn_ops
15+
from tensorflow.python.ops import nn
1516
from tensorflow.python.ops import array_ops
1617
from tensorflow.python.ops import gen_array_ops
1718
from tensorflow.python.ops import math_ops
@@ -878,7 +879,6 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
878879
sess,
879880
sess.graph.as_graph_def(add_shapes=True),
880881
['lrn'],)
881-
882882
tf_output = run_tf_graph(sess, inp_array, 'lrn0_data:0', 'lrn:0')
883883
tvm_output = run_tvm_graph(graph_def,
884884
inp_array,
@@ -889,6 +889,42 @@ def _test_lrn(ishape, size, axis, bias, alpha, beta):
889889
def test_forward_lrn():
890890
_test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
891891

892+
#######################################################################
893+
# l2_normalize
894+
# ------------
895+
def _test_l2_normalize(ishape, eps, axis):
896+
""" testing l2 normalize (uses max, sum, square, sqrt frontend operators)"""
897+
898+
inp_array = np.random.uniform(size=ishape).astype(np.float32)
899+
inp_array.fill(1)
900+
901+
with tf.Graph().as_default():
902+
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="Placeholder")
903+
nn.l2_normalize(in1,
904+
axis=axis,
905+
epsilon=eps,
906+
name=None,
907+
dim=None)
908+
909+
with tf.Session() as sess:
910+
graph_def = tf.graph_util.convert_variables_to_constants(
911+
sess,
912+
sess.graph.as_graph_def(add_shapes=True),
913+
['l2_normalize'],
914+
)
915+
tf_output = run_tf_graph(sess, inp_array, 'Placeholder:0', 'Placeholder:0')
916+
tvm_output = run_tvm_graph(graph_def,
917+
inp_array,
918+
"Placeholder",
919+
tf_output.shape,
920+
tf_output.dtype)
921+
922+
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
923+
sess.close()
924+
def test_forward_l2_normalize():
925+
_test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
926+
927+
#######################################################################
892928
# Main
893929
# ----
894930
if __name__ == '__main__':
@@ -910,3 +946,4 @@ def test_forward_lrn():
910946
test_forward_gather()
911947
test_forward_ptb()
912948
test_forward_lrn()
949+
test_forward_l2_normalize()

0 commit comments

Comments
 (0)