Skip to content

Commit d8d28bf

Browse files
authored
[Relay] Fix TFlite frontend for unpack, stridedslice (#10333)
We found this while converting an RNN model. The relay tflite frontend use squeeze at converting unpack, but when the unpack.axis=0, `None` is passed to relay.squeeze(), which would squeeze all dimensions with length 1, causing different results from TFLite. A possible fix might be, assign the unpack.axis as-is to relay.squeeze() As for stridedslice, when the tflite frontend handles shrink_axis_mask, the wrapped `begin` should be used, instead of the original one which can be negative. It can cause errors at https://github.com/apache/tvm/blob/d65ff6594d4d6db0062537a1d43c0504173b8e5c/include/tvm/topi/detail/strided_slice.h#L140 Related cases are also added to the python test.
1 parent 5a22c56 commit d8d28bf

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

python/tvm/relay/frontend/tflite.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,7 +1667,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
16671667
if begin[index] < 0
16681668
else begin[index]
16691669
)
1670-
m_end[final_index] = begin[index] + 1
1670+
m_end[final_index] = m_begin[final_index] + 1
16711671
m_stride[final_index] = 1
16721672
fshape_indices.append(-2)
16731673
else:
@@ -2705,9 +2705,9 @@ def convert_unpack(self, op):
27052705
unpack_axis = unpack_options.Axis()
27062706

27072707
# Relay doesn't support 'unpack' operator so we use 'split' & 'squeeze' instead.
2708-
# We have to do 'squeeze' along the split axis but Relay expects
2709-
# squeeze_axis to be either None or List.
2710-
squeeze_axis = None if unpack_axis == 0 else [unpack_axis]
2708+
# We have to do 'squeeze' along the split axis.
2709+
# Relay expects squeeze_axis to be List.
2710+
squeeze_axis = [unpack_axis]
27112711

27122712
# Relay doesn't like TupleWrapper of 1 element so we isolate the case of unpacking
27132713
# a tensor by an axis with len(axis) == 1. For reference see convert_split().

tests/python/frontend/tflite/test_forward.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,9 @@ def test_forward_stridedslice():
618618
_test_stridedslice(
619619
(4, 4), [1, 0], [4, 4], [1, 1], "float32", shrink_axis_mask=2, quantized=quantized
620620
)
621+
_test_stridedslice(
622+
(3, 4), [-1, 0], [0, 3], [1, 1], "float32", shrink_axis_mask=1, quantized=quantized
623+
)
621624

622625

623626
#######################################################################
@@ -3186,6 +3189,9 @@ def test_forward_unpack():
31863189
"""UNPACK"""
31873190
_test_unpack(np.array(np.random.uniform(0, 5, (3, 1)), dtype=np.int32), axis=1, num_unpacks=1)
31883191
_test_unpack(np.array(np.random.uniform(0, 5, (3, 4)), dtype=np.float32), axis=0, num_unpacks=3)
3192+
_test_unpack(
3193+
np.array(np.random.uniform(0, 5, (3, 1, 2)), dtype=np.float32), axis=0, num_unpacks=3
3194+
)
31893195
# tflite 1.13 doesn't accept negative axis
31903196
if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
31913197
_test_unpack(

0 commit comments

Comments
 (0)