Skip to content

Commit 0dcbdb3

Browse files
committed
l2normalization operator support for tensorflow
1 parent 54f5adc commit 0dcbdb3

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

nnvm/python/nnvm/frontend/tensorflow.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,21 @@ def _impl(inputs, attr, params):
468468
ignores=['index_type', 'T'])(new_inputs, attr)
469469
return _impl
470470

471+
def _sum():
472+
def _impl(inputs, attr, params):
473+
axis = params.pop(inputs[1].list_output_names()[0]).asnumpy()
474+
return AttrCvt(
475+
op_name='sum',
476+
extras={'axis': axis},
477+
transforms={'keep_dims':'keepdims'},
478+
ignores=['name', 'Tidx'])(inputs[0], attr)
479+
return _impl
480+
481+
def _square():
482+
def _impl(inputs, attr, params):
483+
return _sym.elemwise_mul(inputs[0], inputs[0])
484+
return _impl
485+
471486
def _gather_v2():
472487
"Tensorflow now support only gatherv2"
473488
def _impl(inputs, attr, params):
@@ -663,13 +678,17 @@ def _impl(inputs, in_state_c, in_state_h, attr, params):
663678
'Identity' : _identity(),
664679
'MatMul' : _matmul(),
665680
'MaxPool' : _pooling('max_pool'),
681+
'Add' : _elemwise('add'),
682+
'Sub' : _elemwise('sub'),
666683
'Mul' : _elemwise('mul'),
684+
'Maximum' : _elemwise('max'),
685+
'Minimum' : _elemwise('min'),
686+
'Sum' : _sum(),
687+
'Square' : _square(),
667688
'Relu' : AttrCvt('relu'),
668689
'Reshape' : _reshape(),
669690
'ResizeBilinear' : _resize_bilinear(),
670691
'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
671-
'Sub' : _elemwise('sub'),
672-
'Add' : _elemwise('add'),
673692
'Rsqrt' : _rsqrt(),
674693
'Squeeze' : _squeeze(),
675694
'FusedBatchNorm' : _fused_batch_norm(),

nnvm/tests/python/frontend/tensorflow/test_forward.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from tensorflow.python.framework import constant_op
1313
from tensorflow.python.framework import graph_util
1414
from tensorflow.python.ops import nn_ops
15+
from tensorflow.python.ops import nn
1516
from tensorflow.python.ops import array_ops
1617
from tensorflow.python.ops import gen_array_ops
1718
from tensorflow.python.ops import math_ops
@@ -854,6 +855,41 @@ def _get_sample(data, state):
854855
np.testing.assert_allclose(tf_samples, tvm_samples, rtol=1e-5, atol=1e-5)
855856
assert(tvm_sample_str == tf_sample_str)
856857

858+
#######################################################################
859+
# l2_normalize
860+
# ------------
861+
def _test_l2_normalize(ishape, eps, axis):
862+
""" testing l2 normalize (uses max, sum, square, sqrt frontend operators)"""
863+
864+
inp_array = np.random.uniform(size=ishape).astype(np.float32)
865+
inp_array.fill(1)
866+
867+
with tf.Graph().as_default():
868+
in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="Placeholder")
869+
nn.l2_normalize(in1,
870+
axis=axis,
871+
epsilon=eps,
872+
name=None,
873+
dim=None)
874+
875+
with tf.Session() as sess:
876+
graph_def = tf.graph_util.convert_variables_to_constants(
877+
sess,
878+
sess.graph.as_graph_def(add_shapes=True),
879+
['l2_normalize'],
880+
)
881+
tf_output = run_tf_graph(sess, inp_array, 'Placeholder:0', 'Placeholder:0')
882+
tvm_output = run_tvm_graph(graph_def,
883+
inp_array,
884+
"Placeholder",
885+
tf_output.shape,
886+
tf_output.dtype)
887+
888+
np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3)
889+
sess.close()
890+
def test_forward_l2_normalize():
891+
_test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
892+
857893
#######################################################################
858894
# Main
859895
# ----
@@ -875,3 +911,4 @@ def _get_sample(data, state):
875911
test_forward_stridedslice()
876912
test_forward_gather()
877913
test_forward_ptb()
914+
test_forward_l2_normalize()

0 commit comments

Comments
 (0)