Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
prevent TRT_Logger from being destroyed before TRT engine (#14898) (#15877)
Browse files Browse the repository at this point in the history
* prevent TRT_Logger from being destroyed before TRT engine

* use unique_ptr for trt_logger/parser/engine/executor ownership

* reduce line length for lint
  • Loading branch information
KellenSunderland authored and TaoLv committed Aug 16, 2019
1 parent 964f288 commit bd2b5a2
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 61 deletions.
35 changes: 11 additions & 24 deletions src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,6 @@ using std::endl;

namespace onnx_to_tensorrt {

// Custom deleter: TensorRT objects must be released through their
// destroy() method rather than operator delete.
struct InferDeleter {
  template <typename T>
  void operator()(T* ptr) const {
    if (ptr != nullptr) {
      ptr->destroy();
    }
  }
};

// Wraps a freshly created TensorRT object in a shared_ptr whose deleter
// calls destroy(). Throws std::runtime_error when creation returned null.
template <typename T>
inline std::shared_ptr<T> InferObject(T* raw) {
  if (raw == nullptr) {
    throw std::runtime_error("Failed to create object");
  }
  return std::shared_ptr<T>(raw, InferDeleter());
}

std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) {
int onnx_ir_major = ir_version / 1000000;
int onnx_ir_minor = ir_version % 1000000 / 10000;
Expand All @@ -83,18 +66,20 @@ void PrintVersion() {
<< NV_TENSORRT_PATCH << endl;
}

std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
unique_ptr<nvonnxparser::IParser>,
std::unique_ptr<TRT_Logger> > onnxToTrtCtx(
const std::string& onnx_model,
int32_t max_batch_size,
size_t max_workspace_size,
nvinfer1::ILogger::Severity verbosity,
bool debug_builder) {
GOOGLE_PROTOBUF_VERIFY_VERSION;

TRT_Logger trt_logger(verbosity);
auto trt_builder = InferObject(nvinfer1::createInferBuilder(trt_logger));
auto trt_network = InferObject(trt_builder->createNetwork());
auto trt_parser = nvonnxparser::createParser(trt_network.get(), trt_logger);
auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
auto trt_builder = nvinfer1::createInferBuilder(*trt_logger);
auto trt_network = trt_builder->createNetwork();
auto trt_parser = InferObject(nvonnxparser::createParser(trt_network, *trt_logger));
::ONNX_NAMESPACE::ModelProto parsed_model;
// We check for a valid parse, but the main effect is the side effect
// of populating parsed_model
Expand Down Expand Up @@ -139,8 +124,10 @@ std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
trt_builder->setMaxBatchSize(max_batch_size);
trt_builder->setMaxWorkspaceSize(max_workspace_size);
trt_builder->setDebugSync(debug_builder);
nvinfer1::ICudaEngine* trt_engine = trt_builder->buildCudaEngine(*trt_network.get());
return std::make_tuple(trt_engine, trt_parser);
auto trt_engine = InferObject(trt_builder->buildCudaEngine(*trt_network));
trt_builder->destroy();
trt_network->destroy();
return std::make_tuple(std::move(trt_engine), std::move(trt_parser), std::move(trt_logger));
}

} // namespace onnx_to_tensorrt
Expand Down
66 changes: 42 additions & 24 deletions src/operator/subgraph/tensorrt/onnx_to_tensorrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <NvInfer.h>

#include <fstream>
#include <memory>
#include <iostream>
#include <sstream>
#include <string>
Expand All @@ -40,33 +41,51 @@

namespace onnx_to_tensorrt {

// Deleter functor that disposes of TensorRT objects via destroy()
// (TensorRT objects are not released with operator delete).
struct InferDeleter {
  template <typename T>
  void operator()(T* ptr) const {
    if (ptr != nullptr) {
      ptr->destroy();
    }
  }
};

// Owning smart-pointer alias for TensorRT objects.
template <typename T>
using unique_ptr = std::unique_ptr<T, InferDeleter>;

// Takes ownership of a newly created TensorRT object; throws
// std::runtime_error if the creation call returned null.
template <typename T>
inline unique_ptr<T> InferObject(T* raw) {
  if (raw == nullptr) {
    throw std::runtime_error("Failed to create object");
  }
  return unique_ptr<T>(raw, InferDeleter());
}

// Logger handed to TensorRT. Messages whose severity is at or below
// `_verbosity` are prefixed with a UTC timestamp and a severity tag and
// written to `_ostream` (std::cout by default; the stream is not owned).
//
// NOTE(review): the span as extracted contained both the pre- and
// post-commit bodies merged together (duplicated fields, constructor and
// log()); this is the deduplicated post-commit definition.
class TRT_Logger : public nvinfer1::ILogger {
  nvinfer1::ILogger::Severity _verbosity;  // most verbose level that is printed
  std::ostream* _ostream;                  // destination stream (not owned)
 public:
  TRT_Logger(Severity verbosity = Severity::kWARNING,
             std::ostream& ostream = std::cout) :
    _verbosity(verbosity), _ostream(&ostream) {}
  void log(Severity severity, const char* msg) override {
    if (severity <= _verbosity) {
      time_t rawtime = std::time(0);
      char buf[256];
      strftime(&buf[0], 256, "%Y-%m-%d %H:%M:%S", std::gmtime(&rawtime));
      const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? " BUG" :
                            severity == Severity::kERROR ? " ERROR" :
                            severity == Severity::kWARNING ? "WARNING" :
                            severity == Severity::kINFO ? " INFO" :
                            "UNKNOWN");
      (*_ostream) << "[" << buf << " " << sevstr << "] " << msg << std::endl;
    }
  }
};

std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
unique_ptr<nvonnxparser::IParser>,
std::unique_ptr<TRT_Logger> > onnxToTrtCtx(
const std::string& onnx_model,
int32_t max_batch_size = 32,
size_t max_workspace_size = 1L << 30,
Expand All @@ -75,5 +94,4 @@ std::tuple<nvinfer1::ICudaEngine*, nvonnxparser::IParser*> onnxToTrtCtx(
} // namespace onnx_to_tensorrt

#endif // MXNET_USE_TENSORRT

#endif // MXNET_OPERATOR_SUBGRAPH_TENSORRT_ONNX_TO_TENSORRT_H_
25 changes: 13 additions & 12 deletions src/operator/subgraph/tensorrt/tensorrt-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,14 @@ struct TRTParam {
};

struct TRTEngineParam {
TRTEngineParam(nvinfer1::ICudaEngine* trt_engine,
nvonnxparser::IParser* _parser,
const std::unordered_map<std::string, uint32_t> input_map,
const std::unordered_map<std::string, uint32_t> output_map) {
TRTEngineParam(onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> _trt_engine,
onnx_to_tensorrt::unique_ptr<nvonnxparser::IParser> _trt_parser,
std::unique_ptr<onnx_to_tensorrt::TRT_Logger> _trt_logger,
const std::unordered_map<std::string, uint32_t>& input_map,
const std::unordered_map<std::string, uint32_t>& output_map) {
trt_engine = std::move(_trt_engine);
trt_logger = std::move(_trt_logger);
trt_parser = std::move(_trt_parser);
binding_order = std::make_shared<std::vector<std::pair<uint32_t, bool> > >();
bindings = std::make_shared<std::vector<void*> >();
binding_order->reserve(trt_engine->getNbBindings());
Expand All @@ -67,16 +71,13 @@ struct TRTEngineParam {
binding_order->emplace_back(output_map.at(binding_name), false);
}
}
trt_executor = trt_engine->createExecutionContext();
trt_parser = _parser;
trt_executor = onnx_to_tensorrt::InferObject(trt_engine->createExecutionContext());
}

~TRTEngineParam() {
trt_parser->destroy();
trt_executor->destroy();
}
nvinfer1::IExecutionContext* trt_executor;
nvonnxparser::IParser* trt_parser;
onnx_to_tensorrt::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
onnx_to_tensorrt::unique_ptr<nvinfer1::IExecutionContext> trt_executor;
onnx_to_tensorrt::unique_ptr<nvonnxparser::IParser> trt_parser;
std::unique_ptr<onnx_to_tensorrt::TRT_Logger> trt_logger;
std::shared_ptr<std::vector<std::pair<uint32_t, bool> > > binding_order;
std::shared_ptr<std::vector<void*> > bindings;
};
Expand Down
4 changes: 3 additions & 1 deletion src/operator/subgraph/tensorrt/tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
graph.attrs["shape"] = std::make_shared<nnvm::any>(std::move(shapes));
auto onnx_graph = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(graph, &params_map);
auto trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph, max_batch_size, 1 << 30);
return OpStatePtr::Create<TRTEngineParam>(std::get<0>(trt_tuple), std::get<1>(trt_tuple),
return OpStatePtr::Create<TRTEngineParam>(std::move(std::get<0>(trt_tuple)),
std::move(std::get<1>(trt_tuple)),
std::move(std::get<2>(trt_tuple)),
inputs_to_idx, outputs_to_idx);
}

Expand Down

0 comments on commit bd2b5a2

Please sign in to comment.