Skip to content

Commit

Permalink
working decompression unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Oct 18, 2024
1 parent 81ecf2e commit efedcc2
Showing 1 changed file with 77 additions and 42 deletions.
119 changes: 77 additions & 42 deletions tensorflow/lite/micro/kernels/decompress_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/decompress.h"

#include <algorithm>
#include <initializer_list>
#include <type_traits>

#include "tensorflow/lite/c/builtin_op_data.h"
Expand Down Expand Up @@ -229,7 +230,29 @@ void FillCompressed(uint8_t* compressed, const size_t total_golden_elements,
}

template <typename T>
void TestDecompression(TestingInfo<T>* info) {
void GenerateData(TestingInfo<T>& info) {
FillValueTable(info.total_value_table_elements, info.value_table);
FillGoldens(info.total_elements, info.goldens,
info.total_value_table_elements, info.value_table,
info.channel_count, info.use_alt_axis);
FillCompressed(info.compressed, info.total_elements, info.goldens,
info.total_value_table_elements / info.channel_count,
info.value_table, info.channel_count, info.use_alt_axis,
info.bit_width);
}

template <typename T>
void TestDataSetup(TestingInfo<T>* info, TestingData<T>* data) {
info->output = data->output;
info->goldens = data->goldens;
info->compressed = data->compressed;
info->value_table = data->value_table;
}

template <typename T>
TfLiteStatus TestDecompression(TestingInfo<T>* info) {
GenerateData(*info);

CompressionTensorData ctd = {};
LookupTableData lut_data = {};
ctd.scheme = CompressionScheme::kBinQuant;
Expand All @@ -252,84 +275,96 @@ void TestDecompression(TestingInfo<T>* info) {
for (size_t i = 0; i < info->total_elements; i++) {
TF_LITE_MICRO_EXPECT_EQ(info->goldens[i], info->output[i]);
if (micro_test::did_test_fail) {
return;
return kTfLiteError;
}
}
micro_test::did_test_fail = saved_fail_state;
return kTfLiteOk;
}

template <typename T>
void GenerateData(TestingInfo<T>& info) {
FillValueTable(info.total_value_table_elements, info.value_table);
FillGoldens(info.total_elements, info.goldens,
info.total_value_table_elements, info.value_table,
info.channel_count, info.use_alt_axis);
FillCompressed(info.compressed, info.total_elements, info.goldens,
info.total_value_table_elements / info.channel_count,
info.value_table, info.channel_count, info.use_alt_axis,
info.bit_width);
}

template <typename T>
void TestDataSetup(TestingInfo<T>* info, TestingData<T>* data) {
info->output = data->output;
info->goldens = data->goldens;
info->compressed = data->compressed;
info->value_table = data->value_table;
}

template <typename T>
void TestValueTable2n(TestingInfo<T>& info) {
info.total_elements = 16;
TfLiteStatus TestValueTable2n(TestingInfo<T>& info) {
if (std::is_same<T, bool>::value) {
info.total_value_table_elements = 2 * info.channel_count;
} else {
info.total_value_table_elements =
(1 << info.bit_width) * info.channel_count;
info.total_value_table_elements =
std::min(info.total_value_table_elements, info.total_elements);
}
info.total_value_table_elements =
std::min(info.total_value_table_elements, info.total_elements);
info.total_value_table_elements = std::min(info.total_value_table_elements,
TestingData<T>::kValueTableSize);

MicroPrintf(" Testing value table 2^n: %d",
MicroPrintf(" Testing value table 2^n: %d",
info.total_value_table_elements);
GenerateData(info);
TestDecompression(&info);
return TestDecompression(&info);
}

template <typename T>
void TestValueTable2nMinus1(TestingInfo<T>& info) {
info.total_elements = 16;
TfLiteStatus TestValueTable2nMinus1(TestingInfo<T>& info) {
if (std::is_same<T, bool>::value) {
info.total_value_table_elements = 1 * info.channel_count;
} else {
info.total_value_table_elements =
((1 << info.bit_width) - 1) * info.channel_count;
info.total_value_table_elements =
std::min(info.total_value_table_elements, info.total_elements);
}
info.total_value_table_elements =
std::min(info.total_value_table_elements, info.total_elements);
info.total_value_table_elements = std::min(info.total_value_table_elements,
TestingData<T>::kValueTableSize);

MicroPrintf(" Testing value table 2^n-1: %d",
MicroPrintf(" Testing value table 2^n-1: %d",
info.total_value_table_elements);
GenerateData(info);
TestDecompression(&info);
return TestDecompression(&info);
}

template <typename T>
void TestElementCount(TestingInfo<T>& info) {
static constexpr std::initializer_list<size_t> elements_per_channel{
1, 2,
3, 4,
5, 7,
8, 9,
15, 16,
17, 31,
32, 33,
63, 64,
65, 127,
128, 129,
255, TestingData<T>::kElementsPerChannel};

MicroPrintf(" Testing element count: %d thru %d",
elements_per_channel.begin()[0], elements_per_channel.end()[-1]);

for (size_t i = 0; i < elements_per_channel.size(); i++) {
info.total_elements = elements_per_channel.begin()[i] * info.channel_count;

TfLiteStatus s;
s = TestValueTable2n(info);
if (s == kTfLiteError) {
MicroPrintf(" Failed element count: %d", info.total_elements);
}
s = TestValueTable2nMinus1(info);
if (s == kTfLiteError) {
MicroPrintf(" Failed element count: %d", info.total_elements);
}
}
}

template <typename T>
void TestSingleChannel(TestingInfo<T>& info) {
info.channel_count = 1;

MicroPrintf(" Testing single channel");
TestValueTable2n(info);
TestValueTable2nMinus1(info);
TestElementCount(info);
}

template <typename T>
void TestMultiChannel(TestingInfo<T>& info) {
info.channel_count = 2;
info.channel_count = TestingData<T>::kChannels;

MicroPrintf(" Testing multiple channels: %d", info.channel_count);
TestValueTable2n(info);
TestValueTable2nMinus1(info);
TestElementCount(info);
}

template <typename T>
Expand All @@ -355,7 +390,7 @@ void TestAllBitWidths() {
TestingInfo<T> info = {};
TestDataSetup<T>(&info, GetTestingData<T>());

for (size_t bw = 1; bw <= 7; bw++) {
for (size_t bw = 1; bw <= TestingData<T>::kBitWidth; bw++) {
info.bit_width = bw;

TestBitWidth<T>(info);
Expand Down

0 comments on commit efedcc2

Please sign in to comment.