forked from IRCVLab/AUE8088-PA1
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensorRT.py
31 lines (22 loc) · 1012 Bytes
/
tensorRT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
def converter(onnx_filename, trt_filename, half, workspace_size=3 << 30):
    """Build a TensorRT engine from an ONNX model and write it to disk.

    Args:
        onnx_filename: Path of the ONNX model to parse.
        trt_filename: Path the serialized engine plan is written to.
        half: If truthy, set the FP16 builder flag so TensorRT may choose
            half-precision kernels.
        workspace_size: Builder workspace memory limit in bytes
            (defaults to 3 GiB).

    Raises:
        RuntimeError: If the ONNX model cannot be parsed, or if TensorRT
            fails to build the engine.
    """
    trt_logger = trt.Logger(trt.Logger.INFO)
    # The ONNX parser needs an explicit-batch network on TensorRT 7/8;
    # TensorRT 10 deprecates this flag because all networks are explicit-batch.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(explicit_batch)
    parser = trt.OnnxParser(network, trt_logger)

    builder_config = builder.create_builder_config()
    # ``max_workspace_size`` is deprecated since TensorRT 8.4 and removed in
    # TensorRT 10; use the memory-pool API when the installed version has it.
    if hasattr(builder_config, "set_memory_pool_limit"):
        builder_config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE, workspace_size
        )
    else:
        builder_config.max_workspace_size = workspace_size
    if half:
        builder_config.set_flag(trt.BuilderFlag.FP16)

    with open(onnx_filename, 'rb') as model:
        parsed = parser.parse(model.read())
    if not parsed:
        errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
        for error in errors:
            print(error)
        # Stop here: building a partially-parsed network would only fail
        # later with a far less informative error.
        raise RuntimeError(
            f"Failed to parse ONNX model {onnx_filename!r}:\n" + "\n".join(errors)
        )

    plan = builder.build_serialized_network(network, builder_config)
    if plan is None:
        raise RuntimeError(f"TensorRT engine build failed for {onnx_filename!r}")

    # The builder already returns the serialized engine; write it directly
    # instead of deserializing and re-serializing it through a Runtime.
    with open(trt_filename, 'wb') as f:
        f.write(plan)