diff --git a/src/operator/subgraph/tensorrt/tensorrt.cc b/src/operator/subgraph/tensorrt/tensorrt.cc index e872e9274f21..8395fb43f1d9 100644 --- a/src/operator/subgraph/tensorrt/tensorrt.cc +++ b/src/operator/subgraph/tensorrt/tensorrt.cc @@ -28,6 +28,8 @@ #include "./tensorrt-inl.h" +#include + namespace mxnet { namespace op { @@ -311,7 +313,13 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx, graph.attrs["dtype"] = std::make_shared(std::move(dtypes)); graph.attrs["shape"] = std::make_shared(std::move(shapes)); auto onnx_graph = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(graph, ¶ms_map); - auto trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph, max_batch_size, 1 << 30); + uint32_t verbose = dmlc::GetEnv("MXNET_TENSORRT_VERBOSE", 0); + auto log_lvl = nvinfer1::ILogger::Severity::kWARNING; + if (verbose != 0) { + log_lvl = nvinfer1::ILogger::Severity::kVERBOSE; + } + + auto trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph, max_batch_size, 1 << 30, log_lvl); return OpStatePtr::Create(std::move(std::get<0>(trt_tuple)), std::move(std::get<1>(trt_tuple)), std::move(std::get<2>(trt_tuple)),