Skip to content

Commit

Permalink
Add parser support for ReLU tflite operator (apache#4022)
Browse files Browse the repository at this point in the history
  • Loading branch information
inadob authored and wweic committed Sep 30, 2019
1 parent 7f69a7d commit b866c7d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
18 changes: 18 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab):
'PACK': self.convert_pack,
'LOGISTIC': self.convert_logistic,
'TANH':self.convert_tanh,
'RELU':self.convert_relu,
'SPLIT': self.convert_split,
'TRANSPOSE': self.convert_transpose,
'TILE': self.convert_tile,
Expand Down Expand Up @@ -345,6 +346,23 @@ def convert_tanh(self, op):

return out

def convert_relu(self, op):
"""Convert TFLite ReLU"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = _op.nn.relu(in_expr)

return out

def convert_concatenation(self, op):
"""Convert TFLite concatenation"""
try:
Expand Down
16 changes: 16 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,21 @@ def test_forward_tanh():
""" TANH """
_test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)))

#######################################################################
# ReLu
# --------

def _test_relu(data):
""" One iteration of ReLU """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = nn_ops.relu(in_data)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

def test_forward_relu():
""" ReLU """
_test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))

#######################################################################
# Fully Connected
# -------
Expand Down Expand Up @@ -999,6 +1014,7 @@ def test_forward_ssd_mobilenet_v1():
test_forward_pooling()
test_forward_softmax()
test_forward_tanh()
test_forward_relu()
test_forward_fully_connected()

# Elemwise
Expand Down

0 comments on commit b866c7d

Please sign in to comment.