From b422b7d7a0f03fc1644f1829a510534427e9b9b3 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Fri, 16 Nov 2018 15:08:22 -0800 Subject: [PATCH] [MXNET-703] Minor refactor of TensorRT code --- src/executor/onnx_to_tensorrt.cc | 4 ++-- src/executor/trt_graph_executor.cc | 7 ++++--- src/operator/contrib/nnvm_to_onnx-inl.h | 12 ++++++------ src/operator/contrib/nnvm_to_onnx.cc | 10 +++++----- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc index e3a4ae868ce2..c557527f0b9b 100644 --- a/src/executor/onnx_to_tensorrt.cc +++ b/src/executor/onnx_to_tensorrt.cc @@ -100,8 +100,8 @@ nvinfer1::ICudaEngine* onnxToTrtCtx( } if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) { - int nerror = trt_parser->getNbErrors(); - for ( int i=0; i < nerror; ++i ) { + size_t nerror = trt_parser->getNbErrors(); + for ( size_t i=0; i < nerror; ++i ) { nvonnxparser::IParserError const* error = trt_parser->getError(i); if ( error->node() != -1 ) { ::ONNX_NAMESPACE::NodeProto const& node = diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc index 65dbb29792e0..a5a054b85d3d 100644 --- a/src/executor/trt_graph_executor.cc +++ b/src/executor/trt_graph_executor.cc @@ -1,3 +1,5 @@ +#include + /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -133,7 +135,7 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol, } auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer); - for (auto trt_group : trt_groups) { + for (const auto &trt_group : trt_groups) { if (trt_group.size() > 1) { g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer); g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, @@ -142,7 +144,6 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol, } } - InitArguments(g.indexed_graph(), g.GetAttr("shape"), g.GetAttr("dtype"), g.GetAttr("storage_type"), @@ -434,7 +435,7 @@ Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol, std::unordered_map *shared_buffer, Executor *shared_exec) { auto exec = new exec::TrtGraphExecutor(); - exec->Init(symbol, default_ctx, group2ctx, + exec->Init(std::move(symbol), default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map, grad_req_types, param_names, diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h index 58f88b051433..990b056b6a47 100644 --- a/src/operator/contrib/nnvm_to_onnx-inl.h +++ b/src/operator/contrib/nnvm_to_onnx-inl.h @@ -70,14 +70,14 @@ std::unordered_map GetOutputLookup(const nnvm::IndexedGra void ConvertPlaceholder( const std::string& node_name, const std::unordered_map& placeholder_shapes, - GraphProto* const graph_proto); + GraphProto* graph_proto); -void ConvertConstant(GraphProto* const graph_proto, +void ConvertConstant(GraphProto* graph_proto, const std::string& node_name, - std::unordered_map* const shared_buffer); + std::unordered_map* shared_buffer); -void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map, - GraphProto* const graph_proto, +void ConvertOutput(op::tensorrt::InferenceMap_t* trt_output_map, + GraphProto* graph_proto, const std::unordered_map::iterator& out_iter, const std::string& node_name, const nnvm::Graph& g, @@ -135,7 +135,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto, TRTParam ConvertNnvmGraphToOnnx( const nnvm::Graph &g, - std::unordered_map *const shared_buffer); + std::unordered_map* shared_buffer); static const std::unordered_map converter_map = { {"Convolution", ConvertConvolution}, diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc index 902466614c7c..974eec10e2f7 100644 --- a/src/operator/contrib/nnvm_to_onnx.cc +++ b/src/operator/contrib/nnvm_to_onnx.cc @@ -60,10 +60,10 @@ namespace nnvm_to_onnx { op::TRTParam ConvertNnvmGraphToOnnx( const nnvm::Graph& g, std::unordered_map* const shared_buffer) { - op::TRTParam trt_param; - op::tensorrt::NameToIdx_t trt_input_map; - op::tensorrt::InferenceMap_t trt_output_map; + op::TRTParam trt_param; + op::tensorrt::NameToIdx_t trt_input_map; + op::tensorrt::InferenceMap_t trt_output_map; const nnvm::IndexedGraph& ig = g.indexed_graph(); const auto& storage_types = g.GetAttr("storage_type"); const auto& dtypes = g.GetAttr("dtype"); @@ -240,7 +240,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs, AttributeProto* const kernel_shape = node_proto->add_attribute(); kernel_shape->set_name("kernel_shape"); kernel_shape->set_type(AttributeProto::INTS); - for (int kval : kernel) { + for (dim_t kval : kernel) { kernel_shape->add_ints(static_cast(kval)); } @@ -256,7 +256,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs, AttributeProto* const strides = node_proto->add_attribute(); strides->set_name("strides"); strides->set_type(AttributeProto::INTS); - for (int kval : stride) { + for (dim_t kval : stride) { strides->add_ints(static_cast(kval)); }