forked from tensorflow/tflite-micro
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetadata_test.cc
81 lines (65 loc) · 2.93 KB
/
metadata_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// Copyright 2024 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Test validity of the flatbuffer schema and illustrate use of the flatbuffer
// machinery with C++.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

#include "tensorflow/lite/micro/compression/metadata_saved.h"
#include "tensorflow/lite/micro/hexdump.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

using tflite::micro::compression::LutTensor;
using tflite::micro::compression::LutTensorT;
using tflite::micro::compression::Metadata;
using tflite::micro::compression::MetadataT;
using tflite::micro::compression::Subgraph;
using tflite::micro::compression::SubgraphT;
namespace {

// Compares the object-API struct that was packed into the flatbuffer
// (`expected`) with the read-only table accessor obtained from the serialized
// bytes (`actual`). Returns true only when `tensor`, `value_buffer`, and
// `index_bitwidth` all match, i.e. the round trip preserved every field.
bool operator==(const LutTensorT& expected, const LutTensor& actual) {
  if (expected.tensor != actual.tensor()) {
    return false;
  }
  if (expected.value_buffer != actual.value_buffer()) {
    return false;
  }
  return expected.index_bitwidth == actual.index_bitwidth();
}

}  // end anonymous namespace
TF_LITE_MICRO_TESTS_BEGIN

// Packs a Metadata object holding one subgraph with two LUT tensors into a
// flatbuffer, then reads it back through the generated accessors. Checks that
// the buffer verifies, that the subgraph/LUT-tensor counts match what was
// packed, and that every LutTensor field survived the round trip.
TF_LITE_MICRO_TEST(ReadTest) {
  // Create these objects on the stack and copy them into the subgraph's vector,
  // so they can be compared later to what is read from the flatbuffer.
  LutTensorT lut_tensor0;
  lut_tensor0.tensor = 63;
  lut_tensor0.value_buffer = 128;
  lut_tensor0.index_bitwidth = 2;

  LutTensorT lut_tensor1;
  lut_tensor1.tensor = 64;
  lut_tensor1.value_buffer = 129;
  lut_tensor1.index_bitwidth = 4;

  auto subgraph0 = std::make_unique<SubgraphT>();
  subgraph0->lut_tensors.push_back(std::make_unique<LutTensorT>(lut_tensor0));
  subgraph0->lut_tensors.push_back(std::make_unique<LutTensorT>(lut_tensor1));

  auto metadata = std::make_unique<MetadataT>();
  metadata->subgraphs.push_back(std::move(subgraph0));

  flatbuffers::FlatBufferBuilder builder;
  auto root = Metadata::Pack(builder, metadata.get());
  builder.Finish(root);
  const uint8_t* buffer = builder.GetBufferPointer();
  const size_t buffer_size = builder.GetSize();

  tflite::hexdump({reinterpret_cast<const std::byte*>(buffer), buffer_size});
  std::cout << "length: " << buffer_size << "\n";

  // Confirm the serialized bytes form a well-formed Metadata buffer before
  // walking them with the accessors below.
  flatbuffers::Verifier verifier(buffer, buffer_size);
  const bool verified =
      tflite::micro::compression::VerifyMetadataBuffer(verifier);
  TF_LITE_MICRO_EXPECT(verified);

  // Each structural expectation also guards the dereferences that depend on
  // it, so a failed check is reported rather than followed by a bad access
  // (TF_LITE_MICRO_EXPECT is not relied upon to abort the test).
  if (verified) {
    const Metadata* read_metadata =
        tflite::micro::compression::GetMetadata(buffer);
    const auto* read_subgraphs = read_metadata->subgraphs();
    const bool has_one_subgraph =
        read_subgraphs != nullptr && read_subgraphs->size() == 1u;
    TF_LITE_MICRO_EXPECT(has_one_subgraph);

    if (has_one_subgraph) {
      const Subgraph* read_subgraph0 = read_subgraphs->Get(0);
      const auto* read_lut_tensors = read_subgraph0->lut_tensors();
      const bool has_two_lut_tensors =
          read_lut_tensors != nullptr && read_lut_tensors->size() == 2u;
      TF_LITE_MICRO_EXPECT(has_two_lut_tensors);

      if (has_two_lut_tensors) {
        const LutTensor* read_lut_tensor0 = read_lut_tensors->Get(0);
        const LutTensor* read_lut_tensor1 = read_lut_tensors->Get(1);
        TF_LITE_MICRO_EXPECT(lut_tensor0 == *read_lut_tensor0);
        TF_LITE_MICRO_EXPECT(lut_tensor1 == *read_lut_tensor1);
      }
    }
  }
}

TF_LITE_MICRO_TESTS_END