Skip to content

Commit 31ca96c

Browse files
Maratyszczabddppq
authored andcommitted
Microbenchmark for encoding+decoding ModelProto and GraphProto with a single operator (onnx#609)
To build: cmake -GNinja -DCMAKE_BUILD_TYPE=Release -DONNX_BUILD_BENCHMARKS=ON -DBUILD_PYTHON=OFF ..
1 parent 79dc46f commit 31ca96c

File tree

4 files changed

+164
-0
lines changed

4 files changed

+164
-0
lines changed

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@
22
path = third_party/pybind11
33
url = https://github.com/pybind/pybind11.git
44
branch = master
5+
[submodule "third_party/benchmark"]
6+
path = third_party/benchmark
7+
url = https://github.com/google/benchmark.git

CMakeLists.txt

+14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.1)
33

44
# Project
55
project(onnx C CXX)
6+
option(ONNX_BUILD_BENCHMARKS "Build ONNX micro-benchmarks" OFF)
67

78
option(BUILD_PYTHON "Build Python binaries" ON)
89

@@ -213,6 +214,19 @@ if(BUILD_PYTHON)
213214
endif()
214215
endif()
215216

217+
if(ONNX_BUILD_BENCHMARKS)
218+
if(NOT TARGET benchmark)
219+
# We will not need to test benchmark lib itself.
220+
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark testing as we don't need it.")
221+
# We will not need to install benchmark since we link it statically.
222+
set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable benchmark install to avoid overwriting vendor install.")
223+
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/benchmark)
224+
endif()
225+
226+
add_executable(protobuf-bench tools/protobuf-bench.cc)
227+
target_link_libraries(protobuf-bench onnx_proto benchmark)
228+
endif()
229+
216230
# Export include directories
217231
set(ONNX_INCLUDE_DIRS "${ONNX_ROOT}" "${CMAKE_CURRENT_BINARY_DIR}")
218232
set(ONNX_INCLUDE_DIRS ${ONNX_INCLUDE_DIRS} PARENT_SCOPE)

third_party/benchmark

Submodule benchmark added at 491360b

tools/protobuf-bench.cc

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#include <benchmark/benchmark.h>
2+
3+
#include <onnx/onnx.pb.h>
4+
5+
using namespace ONNX_NAMESPACE;
6+
7+
8+
inline void createValueInfo4D(
9+
ValueInfoProto& value_info,
10+
const std::string& name,
11+
int64_t n,
12+
int64_t c,
13+
int64_t h,
14+
int64_t w) {
15+
value_info.set_name(name);
16+
17+
TypeProto_Tensor* tensor_type =
18+
value_info.mutable_type()->mutable_tensor_type();
19+
tensor_type->set_elem_type(TensorProto_DataType_FLOAT);
20+
21+
TensorShapeProto* shape = tensor_type->mutable_shape();
22+
shape->add_dim()->set_dim_value(n);
23+
shape->add_dim()->set_dim_value(c);
24+
shape->add_dim()->set_dim_value(h);
25+
shape->add_dim()->set_dim_value(w);
26+
}
27+
28+
inline void createValueInfo2D(
29+
ValueInfoProto& value_info,
30+
const std::string& name,
31+
int64_t h,
32+
int64_t w) {
33+
value_info.set_name(name);
34+
35+
TypeProto* type = value_info.mutable_type();
36+
37+
TypeProto_Tensor* tensor_type = type->mutable_tensor_type();
38+
tensor_type->set_elem_type(TensorProto_DataType_FLOAT);
39+
TensorShapeProto* shape = tensor_type->mutable_shape();
40+
shape->add_dim()->set_dim_value(h);
41+
shape->add_dim()->set_dim_value(w);
42+
}
43+
44+
inline void createConv2D(
45+
NodeProto& node,
46+
const std::string& input,
47+
const std::string& weights,
48+
const std::string& bias,
49+
const std::string& output,
50+
uint32_t kernel_size) {
51+
node.set_op_type("Conv");
52+
node.add_input(input);
53+
node.add_input(weights);
54+
node.add_input(bias);
55+
node.add_output(output);
56+
57+
{
58+
AttributeProto* kernel = node.add_attribute();
59+
kernel->set_name("kernel_shape");
60+
kernel->set_type(AttributeProto::INTS);
61+
kernel->add_ints(kernel_size);
62+
kernel->add_ints(kernel_size);
63+
}
64+
{
65+
AttributeProto* dilation = node.add_attribute();
66+
dilation->set_name("dilations");
67+
dilation->set_type(AttributeProto::INTS);
68+
dilation->add_ints(1);
69+
dilation->add_ints(1);
70+
}
71+
{
72+
AttributeProto* stride = node.add_attribute();
73+
stride->set_name("strides");
74+
stride->set_type(AttributeProto::INTS);
75+
stride->add_ints(1);
76+
stride->add_ints(1);
77+
}
78+
{
79+
AttributeProto* group = node.add_attribute();
80+
group->set_name("group");
81+
group->set_type(AttributeProto::INTS);
82+
group->set_i(1);
83+
}
84+
{
85+
AttributeProto* padding = node.add_attribute();
86+
padding->set_name("pads");
87+
padding->set_type(AttributeProto::INTS);
88+
/* Use "same" padding */
89+
padding->add_ints(kernel_size / 2);
90+
padding->add_ints(kernel_size / 2);
91+
padding->add_ints(kernel_size - 1 - kernel_size / 2);
92+
padding->add_ints(kernel_size - 1 - kernel_size / 2);
93+
}
94+
}
95+
96+
static void ConvGraph(benchmark::State& state) {
97+
while (state.KeepRunning()) {
98+
std::string data;
99+
GraphProto graph;
100+
101+
createConv2D(*graph.add_node(), "input", "weights", "bias", "output", 3);
102+
103+
createValueInfo4D(*graph.add_input(), "input", 1, 3, 224, 224);
104+
createValueInfo4D(*graph.add_input(), "weights", 16, 16, 3, 3);
105+
createValueInfo2D(*graph.add_input(), "bias", 1, 16);
106+
createValueInfo4D(*graph.add_output(), "output", 16, 3, 224, 224);
107+
108+
graph.SerializeToString(&data);
109+
110+
GraphProto decodedGraph;
111+
decodedGraph.ParseFromString(data);
112+
}
113+
114+
state.SetItemsProcessed(int64_t(state.iterations()));
115+
}
116+
BENCHMARK(ConvGraph)->Unit(benchmark::kMicrosecond);
117+
118+
static void ConvModel(benchmark::State& state) {
119+
while (state.KeepRunning()) {
120+
std::string data;
121+
ModelProto model;
122+
model.set_ir_version(IR_VERSION);
123+
OperatorSetIdProto* op_set_id = model.add_opset_import();
124+
op_set_id->set_domain("");
125+
op_set_id->set_version(4);
126+
127+
GraphProto* graph = model.mutable_graph();
128+
129+
createConv2D(*graph->add_node(), "input", "weights", "bias", "output", 3);
130+
131+
createValueInfo4D(*graph->add_input(), "input", 1, 3, 224, 224);
132+
createValueInfo4D(*graph->add_input(), "weights", 16, 16, 3, 3);
133+
createValueInfo2D(*graph->add_input(), "bias", 1, 16);
134+
createValueInfo4D(*graph->add_output(), "output", 16, 3, 224, 224);
135+
136+
model.SerializeToString(&data);
137+
138+
ModelProto decodedModel;
139+
decodedModel.ParseFromString(data);
140+
}
141+
142+
state.SetItemsProcessed(int64_t(state.iterations()));
143+
}
144+
BENCHMARK(ConvModel)->Unit(benchmark::kMicrosecond);
145+
146+
BENCHMARK_MAIN();

0 commit comments

Comments
 (0)