Fix TensorListStack so it properly handles the case where the list has no elements set and the element shape has a zero somewhere in it but is not a scalar.

PiperOrigin-RevId: 546449625
LukeBoyer authored and tensorflower-gardener committed Jul 8, 2023
1 parent 14df49d commit c05598a
Showing 5 changed files with 136 additions and 32 deletions.
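A rough illustration of the case this commit addresses, as a hypothetical standalone C++ sketch (not part of the change): stacking a list whose element shape contains a zero, say [2, 0], must produce a tensor whose leading dimension is the list length and whose data buffer is empty, because the element count is a product over all dimensions and any zero dimension makes it zero.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical sketch: a list of 2 elements whose element shape is [2, 0].
  const int num_list_elements = 2;
  const std::vector<int> element_shape = {2, 0};  // has a zero, not a scalar

  // Stacking prepends the list length to the element shape: [2, 2, 0].
  std::vector<int> stacked_shape = {num_list_elements};
  stacked_shape.insert(stacked_shape.end(), element_shape.begin(),
                       element_shape.end());

  // The element count is the product over all dimensions, so the zero
  // dimension makes the stacked tensor empty.
  std::size_t num_elements = 1;
  for (int d : stacked_shape) num_elements *= d;

  std::cout << "stacked elements: " << num_elements                // 0
            << ", bytes: " << num_elements * sizeof(int) << "\n";  // 0
  return 0;
}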
30 changes: 5 additions & 25 deletions tensorflow/lite/kernels/variants/list_kernels/list_stack.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstring>
#include <limits>
#include <utility>

#include "tensorflow/lite/array.h"
@@ -76,6 +75,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_TYPES_EQ(context, output->type, arr->ElementType());

IntArrayUniquePtr cur_shape_suffix;

// If succeeds and result not nullptr, guaranteed to be fully defined.
TF_LITE_ENSURE_OK(context, GetShapeIfAllEqual(*arr, cur_shape_suffix));

@@ -106,38 +106,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cur_shape_suffix->size * sizeof(int));
final_output_shape->data[0] = arr->NumElements();

// Length zero will result in a tensor with empty allocation, so clear
// data just in case and short circuit.
if (arr->NumElements() == 0) {
TfLiteTensorDataFree(output);
if (output->dims) {
TfLiteIntArrayFree(output->dims);
}
output->dims = final_output_shape.release();
output->bytes = 0;
return kTfLiteOk;
}

} else {
final_output_shape = BuildTfLiteArray({arr->NumElements()});
}

context->ResizeTensor(context, output, final_output_shape.release());

int num_elements = 1;
for (int i = 0; i < output->dims->size; ++i) {
const int d = output->dims->data[i];
if (d > 0) {
// Check overflow.
TF_LITE_ENSURE(context,
num_elements < std::numeric_limits<int>().max() / d);
num_elements *= d;
}
const int num_elements = NumElements(output);
if (num_elements == 0) {
TfLiteTensorDataFree(output);
return kTfLiteOk;
}

TF_LITE_ENSURE_EQ(context, output->bytes,
num_elements * TfLiteTypeGetSize(output->type));

// This has to be an int and we would have returned already if divisor == 0.
const int element_num_elements = num_elements / output->dims->data[0];
const size_t bytes_per_element =
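The core of the kernel change above: the old loop multiplied only dimensions with d > 0 (so the overflow guard never divides by zero or a dynamic dimension), which means a zero dimension was silently skipped. A shape like {2, 2, 0} was therefore counted as 4 elements, and the later TF_LITE_ENSURE_EQ against the zero-byte allocation could not hold. Counting every dimension instead drives num_elements to 0, and the kernel returns early. A standalone sketch of the two behaviors, with illustrative helper names rather than the TFLite functions:

#include <cassert>
#include <vector>

// Old-style product: non-positive dims are skipped, so a zero dim is ignored
// and the element count is overstated for shapes like {2, 2, 0}.
int ProductOfPositiveDims(const std::vector<int>& dims) {
  int n = 1;
  for (int d : dims) {
    if (d > 0) n *= d;
  }
  return n;
}

// NumElements-style product over every dim: any zero dim makes the whole
// tensor empty, which is the behavior the fixed kernel relies on.
int ProductOfAllDims(const std::vector<int>& dims) {
  int n = 1;
  for (int d : dims) n *= d;
  return n;
}

int main() {
  const std::vector<int> dims = {2, 2, 0};
  assert(ProductOfPositiveDims(dims) == 4);  // overstated element count
  assert(ProductOfAllDims(dims) == 0);       // correct: a zero-byte tensor
  return 0;
}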
64 changes: 63 additions & 1 deletion tensorflow/lite/kernels/variants/list_kernels/list_stack_test.cc
@@ -39,9 +39,19 @@ class ListStackModel : public ListOpModel {
SetCustomOp("ListStack", {}, Register_LIST_STACK);
BuildInterpreter({{}, {1}});
}

ListStackModel(TensorData output_data, TensorData shape_input_data) {
tensor_id_ = AddOutput(output_data);
list_id_ = AddInput({TensorType_VARIANT, {}});
shape_id_ = AddInput(shape_input_data);
SetCustomOp("ListStack", {}, Register_LIST_STACK);
BuildInterpreter({{}, shape_input_data.shape});
}

const TfLiteTensor* GetOutputTensor(int tensor_id) {
return interpreter_->tensor(tensor_id);
}

int tensor_id_;
int shape_id_;
int list_id_;
@@ -248,7 +258,7 @@ TEST(ListStackTest, MismatchedOutput_ReturnsResizedOutput1D) {
}

TEST(ListStackTest, MismatchedOutput_ReturnsResizedOutput2D) {
ListStackModel m({TensorType_INT32, {}});
ListStackModel m({TensorType_INT32, std::vector<int>{}});

m.PopulateListTensor(m.list_id_, {}, 2, kTfLiteInt32);
m.PopulateTensor<int>(m.shape_id_, {2});
@@ -259,5 +269,57 @@ TEST(ListStackTest, MismatchedOutput_ReturnsResizedOutput2D) {
EXPECT_THAT(output, DimsAre({2, 2}));
}

TEST(ListStackTest, Trailing0DimInElementShape1D_NonZeroLen_Returns2DNoData) {
ListStackModel m({TensorType_INT32, std::vector<int>{}});

m.PopulateListTensor(m.list_id_, {}, 2, kTfLiteInt32);
m.PopulateTensor<int>(m.shape_id_, {0});

ASSERT_EQ(m.Invoke(), kTfLiteOk);
const TfLiteTensor* output = m.GetOutputTensor(m.tensor_id_);

ASSERT_THAT(output, DimsAre({2, 0}));
EXPECT_EQ(output->bytes, 0);
}

TEST(ListStackTest, Trailing0DimInElementShape2D_NonZeroLen_Returns3DNoData) {
ListStackModel m({TensorType_INT32, {}}, {TensorType_INT32, {2}});

m.PopulateListTensor(m.list_id_, {}, 2, kTfLiteInt32);
m.PopulateTensor<int>(m.shape_id_, {2, 0});

ASSERT_EQ(m.Invoke(), kTfLiteOk);
const TfLiteTensor* output = m.GetOutputTensor(m.tensor_id_);

ASSERT_THAT(output, DimsAre({2, 2, 0}));
EXPECT_EQ(output->bytes, 0);
}

TEST(ListStackTest, Trailing0DimInElementShape1D_ZeroLen_Returns2DNoData) {
ListStackModel m({TensorType_INT32, {}}, {TensorType_INT32, {1}});

m.PopulateListTensor(m.list_id_, {}, 0, kTfLiteInt32);
m.PopulateTensor<int>(m.shape_id_, {0});

ASSERT_EQ(m.Invoke(), kTfLiteOk);
const TfLiteTensor* output = m.GetOutputTensor(m.tensor_id_);

ASSERT_THAT(output, DimsAre({0, 0}));
EXPECT_EQ(output->bytes, 0);
}

TEST(ListStackTest, Trailing0DimInElementShape2D_ZeroLen_Returns3DNoData) {
ListStackModel m({TensorType_INT32, {}}, {TensorType_INT32, {2}});

m.PopulateListTensor(m.list_id_, {}, 0, kTfLiteInt32);
m.PopulateTensor<int>(m.shape_id_, {2, 0});

ASSERT_EQ(m.Invoke(), kTfLiteOk);
const TfLiteTensor* output = m.GetOutputTensor(m.tensor_id_);

ASSERT_THAT(output, DimsAre({0, 2, 0}));
EXPECT_EQ(output->bytes, 0);
}

} // namespace
} // namespace tflite
10 changes: 8 additions & 2 deletions tensorflow/lite/kernels/variants/list_ops_util.cc
@@ -36,10 +36,16 @@ IntArrayUniquePtr TensorAsShape(const TfLiteTensor& shape) {
}

IntArrayUniquePtr MergeShapesOrNull(IntArrayUniquePtr l, IntArrayUniquePtr r) {
if (l == nullptr || l->size == 0) {
if (l == nullptr) {
return r;
}
if (r == nullptr || r->size == 0) {
if (r == nullptr) {
return l;
}
if (l->size == 0) {
return r;
}
if (r->size == 0) {
return l;
}
if (l->size != r->size) {
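The reordering in MergeShapesOrNull above changes one corner case: previously an unranked (size-zero) shape on the left caused the right operand to be returned even when it was nullptr, dropping the unranked shape. With the nullptr checks performed first, a null operand always yields the other operand, as the new NullInput_ReturnsUnrankedOther test below verifies. A self-contained sketch of the reordered logic, using stand-in types (std::vector instead of TfLiteIntArray, per-dimension merging elided):

#include <cassert>
#include <memory>
#include <vector>

// nullptr stands for "no shape information", an empty vector for an
// unranked shape.
using Shape = std::unique_ptr<std::vector<int>>;

Shape MergeShapesOrNull(Shape l, Shape r) {
  if (l == nullptr) return r;
  if (r == nullptr) return l;
  if (l->empty()) return r;
  if (r->empty()) return l;
  if (l->size() != r->size()) return nullptr;
  // ... per-dimension merge elided in this sketch ...
  return l;
}

int main() {
  // With the nullptr checks first, merging an unranked (empty) shape with
  // nullptr keeps the unranked shape instead of returning nullptr.
  Shape merged =
      MergeShapesOrNull(std::make_unique<std::vector<int>>(), nullptr);
  assert(merged != nullptr && merged->empty());
  return 0;
}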
15 changes: 11 additions & 4 deletions tensorflow/lite/kernels/variants/list_ops_util_test.cc
@@ -128,13 +128,20 @@ TEST(MergeShapesOrNull, UnrankedAndRankedUnknown_ReturnsRankedUnknown) {
}

TEST(MergeShapesOrNull, NullInput_ReturnsOther) {
IntArrayUniquePtr l = BuildTfLiteArray({3});
IntArrayUniquePtr r = BuildTfLiteArray({2});
EXPECT_THAT(MergeShapesOrNull(std::move(l), nullptr).get(), DimsAre({3}));
EXPECT_THAT(MergeShapesOrNull(nullptr, std::move(r)).get(), DimsAre({2}));
EXPECT_THAT(MergeShapesOrNull(BuildTfLiteArray({3}), nullptr).get(),
DimsAre({3}));
EXPECT_THAT(MergeShapesOrNull(nullptr, BuildTfLiteArray({2})).get(),
DimsAre({2}));
EXPECT_EQ(MergeShapesOrNull(nullptr, nullptr).get(), nullptr);
}

TEST(MergeShapesOrNull, NullInput_ReturnsUnrankedOther) {
EXPECT_THAT(MergeShapesOrNull(BuildTfLiteArray({}), nullptr).get(),
DimsAre({}));
EXPECT_THAT(MergeShapesOrNull(nullptr, BuildTfLiteArray({})).get(),
DimsAre({}));
}

TEST(ElementsSameShape, NoElements_SucceedsWithNullptr) {
TensorArray arr = {kTfLiteInt32, BuildTfLiteArray({})};
arr.Resize(2);
49 changes: 49 additions & 0 deletions tensorflow/lite/kernels/variants/py/end_to_end_test.py
@@ -144,6 +144,55 @@ def test_register_list_ops_and_invoke_dynamic_shape(
self.assertEqual(tf_out.shape, output_tensor.shape)
self.assertTrue((tf_out == output_tensor).numpy().all())

@parameterized.named_parameters(
("ZeroElements_ScalarStackShape", [], 0),
("NonZeroElements_ScalarStackShape", [], 2),
("NonZeroElements_ZeroStackShape", [0], 2),
("ZeroElements_ZeroStackShape", [0], 0),
("ZeroElements_2DZeroStackShape", [0, 2], 0),
("NonZeroElements_2DZeroStackShape", [0, 2], 2),
)
def test_stack_empty_list(
self, stack_element_shape: list[int], num_elements: int
):
@tf.function(
input_signature=[
tf.TensorSpec(shape=tf.TensorShape(None), dtype=tf.int32)
]
)
def reserve_stack(stack_element_shape) -> tf.Tensor:
l = list_ops.tensor_list_reserve(
element_shape=tf.TensorShape(None),
element_dtype=tf.float32,
num_elements=num_elements,
)
return list_ops.tensor_list_stack(
l, element_shape=stack_element_shape, element_dtype=tf.float32
)

interpreter = self._get_interpreter_from_c_func(reserve_stack)

input_index = interpreter.get_input_details()[0]["index"]

interpreter.resize_tensor_input(input_index, [len(stack_element_shape)])

interpreter.allocate_tensors()

input_tensor = np.array(stack_element_shape, dtype=np.int32)
interpreter.set_tensor(input_index, input_tensor)

interpreter.invoke()

output_tensor = interpreter.get_tensor(
interpreter.get_output_details()[0]["index"]
)

tf_out = reserve_stack(input_tensor)

self.assertEqual(tf_out.dtype, output_tensor.dtype)
self.assertEqual(tf_out.shape, output_tensor.shape)
self.assertTrue((tf_out == output_tensor).numpy().all())


if __name__ == "__main__":
googletest.main()
