diff --git a/tensorflow/lite/kernels/variants/list_kernels/list_stack.cc b/tensorflow/lite/kernels/variants/list_kernels/list_stack.cc index cd4ca10452a0a5..a0f5aceec8a3e8 100644 --- a/tensorflow/lite/kernels/variants/list_kernels/list_stack.cc +++ b/tensorflow/lite/kernels/variants/list_kernels/list_stack.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include #include #include "tensorflow/lite/array.h" @@ -76,6 +75,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_TYPES_EQ(context, output->type, arr->ElementType()); IntArrayUniquePtr cur_shape_suffix; + // If succeeds and result not nullptr, guaranteed to be fully defined. TF_LITE_ENSURE_OK(context, GetShapeIfAllEqual(*arr, cur_shape_suffix)); @@ -106,38 +106,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { cur_shape_suffix->size * sizeof(int)); final_output_shape->data[0] = arr->NumElements(); - // Length zero will result in a tensor with empty allocation, so clear - // data just in case and short circuit. - if (arr->NumElements() == 0) { - TfLiteTensorDataFree(output); - if (output->dims) { - TfLiteIntArrayFree(output->dims); - } - output->dims = final_output_shape.release(); - output->bytes = 0; - return kTfLiteOk; - } - } else { final_output_shape = BuildTfLiteArray({arr->NumElements()}); } context->ResizeTensor(context, output, final_output_shape.release()); - int num_elements = 1; - for (int i = 0; i < output->dims->size; ++i) { - const int d = output->dims->data[i]; - if (d > 0) { - // Check overflow. - TF_LITE_ENSURE(context, - num_elements < std::numeric_limits().max() / d); - num_elements *= d; - } + const int num_elements = NumElements(output); + if (num_elements == 0) { + TfLiteTensorDataFree(output); + return kTfLiteOk; } - TF_LITE_ENSURE_EQ(context, output->bytes, - num_elements * TfLiteTypeGetSize(output->type)); - // This has to be an int and we would have returned already if divisor == 0. const int element_num_elements = num_elements / output->dims->data[0]; const size_t bytes_per_element = diff --git a/tensorflow/lite/kernels/variants/list_kernels/list_stack_test.cc b/tensorflow/lite/kernels/variants/list_kernels/list_stack_test.cc index 412072e7aab8ee..6269dc77bf2578 100644 --- a/tensorflow/lite/kernels/variants/list_kernels/list_stack_test.cc +++ b/tensorflow/lite/kernels/variants/list_kernels/list_stack_test.cc @@ -39,9 +39,19 @@ class ListStackModel : public ListOpModel { SetCustomOp("ListStack", {}, Register_LIST_STACK); BuildInterpreter({{}, {1}}); } + + ListStackModel(TensorData output_data, TensorData shape_input_data) { + tensor_id_ = AddOutput(output_data); + list_id_ = AddInput({TensorType_VARIANT, {}}); + shape_id_ = AddInput(shape_input_data); + SetCustomOp("ListStack", {}, Register_LIST_STACK); + BuildInterpreter({{}, shape_input_data.shape}); + } + const TfLiteTensor* GetOutputTensor(int tensor_id) { return interpreter_->tensor(tensor_id); } + int tensor_id_; int shape_id_; int list_id_; @@ -248,7 +258,7 @@ TEST(ListStackTest, MismatchedOutput_ReturnsResizedOutput1D) { } TEST(ListStackTest, MismatchedOutput_ReturnsResizedOutput2D) { - ListStackModel m({TensorType_INT32, {}}); + ListStackModel m({TensorType_INT32, std::vector{}}); m.PopulateListTensor(m.list_id_, {}, 2, kTfLiteInt32); m.PopulateTensor(m.shape_id_, {2}); @@ -259,5 +269,57 @@ TEST(ListStackTest, MismatchedOutput_ReturnsResizedOutput2D) { EXPECT_THAT(output, DimsAre({2, 2})); } +TEST(ListStackTest, Trailing0DimInElementShape1D_NonZeroLen_Returns2DNoData) { + ListStackModel m({TensorType_INT32, std::vector{}}); + + m.PopulateListTensor(m.list_id_, {}, 2, kTfLiteInt32); + m.PopulateTensor(m.shape_id_, {0}); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + const TfLiteTensor* output = m.GetOutputTensor(m.tensor_id_); + + ASSERT_THAT(output, DimsAre({2, 0})); + EXPECT_EQ(output->bytes, 0); +} + +TEST(ListStackTest, Trailing0DimInElementShape2D_NonZeroLen_Returns3DNoData) { + ListStackModel m({TensorType_INT32, {}}, {TensorType_INT32, {2}}); + + m.PopulateListTensor(m.list_id_, {}, 2, kTfLiteInt32); + m.PopulateTensor(m.shape_id_, {2, 0}); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + const TfLiteTensor* output = m.GetOutputTensor(m.tensor_id_); + + ASSERT_THAT(output, DimsAre({2, 2, 0})); + EXPECT_EQ(output->bytes, 0); +} + +TEST(ListStackTest, Trailing0DimInElementShape1D_ZeroLen_Returns2DNoData) { + ListStackModel m({TensorType_INT32, {}}, {TensorType_INT32, {1}}); + + m.PopulateListTensor(m.list_id_, {}, 0, kTfLiteInt32); + m.PopulateTensor(m.shape_id_, {0}); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + const TfLiteTensor* output = m.GetOutputTensor(m.tensor_id_); + + ASSERT_THAT(output, DimsAre({0, 0})); + EXPECT_EQ(output->bytes, 0); +} + +TEST(ListStackTest, Trailing0DimInElementShape2D_ZeroLen_Returns3DNoData) { + ListStackModel m({TensorType_INT32, {}}, {TensorType_INT32, {2}}); + + m.PopulateListTensor(m.list_id_, {}, 0, kTfLiteInt32); + m.PopulateTensor(m.shape_id_, {2, 0}); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + const TfLiteTensor* output = m.GetOutputTensor(m.tensor_id_); + + ASSERT_THAT(output, DimsAre({0, 2, 0})); + EXPECT_EQ(output->bytes, 0); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/variants/list_ops_util.cc b/tensorflow/lite/kernels/variants/list_ops_util.cc index 6200c32b39112b..447e09e952669d 100644 --- a/tensorflow/lite/kernels/variants/list_ops_util.cc +++ b/tensorflow/lite/kernels/variants/list_ops_util.cc @@ -36,10 +36,16 @@ IntArrayUniquePtr TensorAsShape(const TfLiteTensor& shape) { } IntArrayUniquePtr MergeShapesOrNull(IntArrayUniquePtr l, IntArrayUniquePtr r) { - if (l == nullptr || l->size == 0) { + if (l == nullptr) { return r; } - if (r == nullptr || r->size == 0) { + if (r == nullptr) { + return l; + } + if (l->size == 0) { + return r; + } + if (r->size == 0) { return l; } if (l->size != r->size) { diff --git a/tensorflow/lite/kernels/variants/list_ops_util_test.cc b/tensorflow/lite/kernels/variants/list_ops_util_test.cc index d9212236e4a36e..cee1e9a9b6c408 100644 --- a/tensorflow/lite/kernels/variants/list_ops_util_test.cc +++ b/tensorflow/lite/kernels/variants/list_ops_util_test.cc @@ -128,13 +128,20 @@ TEST(MergeShapesOrNull, UnrankedAndRankedUnknown_ReturnsRankedUnknown) { } TEST(MergeShapesOrNull, NullInput_ReturnsOther) { - IntArrayUniquePtr l = BuildTfLiteArray({3}); - IntArrayUniquePtr r = BuildTfLiteArray({2}); - EXPECT_THAT(MergeShapesOrNull(std::move(l), nullptr).get(), DimsAre({3})); - EXPECT_THAT(MergeShapesOrNull(nullptr, std::move(r)).get(), DimsAre({2})); + EXPECT_THAT(MergeShapesOrNull(BuildTfLiteArray({3}), nullptr).get(), + DimsAre({3})); + EXPECT_THAT(MergeShapesOrNull(nullptr, BuildTfLiteArray({2})).get(), + DimsAre({2})); EXPECT_EQ(MergeShapesOrNull(nullptr, nullptr).get(), nullptr); } +TEST(MergeShapesOrNull, NullInput_ReturnsUnrankedOther) { + EXPECT_THAT(MergeShapesOrNull(BuildTfLiteArray({}), nullptr).get(), + DimsAre({})); + EXPECT_THAT(MergeShapesOrNull(nullptr, BuildTfLiteArray({})).get(), + DimsAre({})); +} + TEST(ElementsSameShape, NoElements_SucceedsWithNullptr) { TensorArray arr = {kTfLiteInt32, BuildTfLiteArray({})}; arr.Resize(2); diff --git a/tensorflow/lite/kernels/variants/py/end_to_end_test.py b/tensorflow/lite/kernels/variants/py/end_to_end_test.py index deaf25ee8946d1..dc690c598027cf 100644 --- a/tensorflow/lite/kernels/variants/py/end_to_end_test.py +++ b/tensorflow/lite/kernels/variants/py/end_to_end_test.py @@ -144,6 +144,55 @@ def test_register_list_ops_and_invoke_dynamic_shape( self.assertEqual(tf_out.shape, output_tensor.shape) self.assertTrue((tf_out == output_tensor).numpy().all()) + @parameterized.named_parameters( + ("ZeroElements_ScalarStackShape", [], 0), + ("NonZeroElements_ScalarStackShape", [], 2), + ("NonZeroElements_ZeroStackShape", [0], 2), + ("ZeroElements_ZeroStackShape", [0], 0), + ("ZeroElements_2DZeroStackShape", [0, 2], 0), + ("NonZeroElements_2DZeroStackShape", [0, 2], 2), + ) + def test_stack_empty_list( + self, stack_element_shape: list[int], num_elements: int + ): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=tf.TensorShape(None), dtype=tf.int32) + ] + ) + def reserve_stack(stack_element_shape) -> tf.Tensor: + l = list_ops.tensor_list_reserve( + element_shape=tf.TensorShape(None), + element_dtype=tf.float32, + num_elements=num_elements, + ) + return list_ops.tensor_list_stack( + l, element_shape=stack_element_shape, element_dtype=tf.float32 + ) + + interpreter = self._get_interpreter_from_c_func(reserve_stack) + + input_index = interpreter.get_input_details()[0]["index"] + + interpreter.resize_tensor_input(input_index, [len(stack_element_shape)]) + + interpreter.allocate_tensors() + + input_tensor = np.array(stack_element_shape, dtype=np.int32) + interpreter.set_tensor(input_index, input_tensor) + + interpreter.invoke() + + output_tensor = interpreter.get_tensor( + interpreter.get_output_details()[0]["index"] + ) + + tf_out = reserve_stack(input_tensor) + + self.assertEqual(tf_out.dtype, output_tensor.dtype) + self.assertEqual(tf_out.shape, output_tensor.shape) + self.assertTrue((tf_out == output_tensor).numpy().all()) + if __name__ == "__main__": googletest.main()