This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-703] Minor refactor of TensorRT code
KellenSunderland committed Nov 17, 2018
1 parent 6ae5b65 commit b422b7d
Showing 4 changed files with 17 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/executor/onnx_to_tensorrt.cc
@@ -100,8 +100,8 @@ nvinfer1::ICudaEngine* onnxToTrtCtx(
}

if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
- int nerror = trt_parser->getNbErrors();
- for ( int i=0; i < nerror; ++i ) {
+ size_t nerror = trt_parser->getNbErrors();
+ for ( size_t i=0; i < nerror; ++i ) {
nvonnxparser::IParserError const* error = trt_parser->getError(i);
if ( error->node() != -1 ) {
::ONNX_NAMESPACE::NodeProto const& node =
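
The hunk above widens the error-loop index from int to size_t. The sketch below shows the same pattern against stand-in types (FakeParser, ParseError, and their getters are hypothetical, not the real nvonnxparser::IParser interface): the counter type is chosen to match the count it is compared against, so the comparison involves no signed/unsigned mixing.

// Minimal, compilable sketch of the error-iteration pattern with a fake parser.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct ParseError {
  int node;           // -1 when the error is not tied to a specific graph node
  std::string desc;
};

struct FakeParser {
  std::vector<ParseError> errors;
  std::size_t getNbErrors() const { return errors.size(); }
  const ParseError* getError(std::size_t i) const { return &errors[i]; }
};

int main() {
  FakeParser parser{{{-1, "model could not be parsed"}, {3, "unsupported op"}}};
  std::size_t nerror = parser.getNbErrors();
  for (std::size_t i = 0; i < nerror; ++i) {
    const ParseError* error = parser.getError(i);
    if (error->node != -1) {
      std::cout << "node " << error->node << ": " << error->desc << "\n";
    } else {
      std::cout << error->desc << "\n";
    }
  }
  return 0;
}
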
7 changes: 4 additions & 3 deletions src/executor/trt_graph_executor.cc
@@ -1,3 +1,5 @@
+ #include <utility>
+
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@@ -133,7 +135,7 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
}

auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
- for (auto trt_group : trt_groups) {
+ for (const auto &trt_group : trt_groups) {
if (trt_group.size() > 1) {
g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
@@ -142,7 +144,6 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
}
}


InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
g.GetAttr<nnvm::DTypeVector>("dtype"),
g.GetAttr<StorageTypeVector>("storage_type"),
@@ -434,7 +435,7 @@ Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray> *shared_buffer,
Executor *shared_exec) {
auto exec = new exec::TrtGraphExecutor();
- exec->Init(symbol, default_ctx, group2ctx,
+ exec->Init(std::move(symbol), default_ctx, group2ctx,
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
arg_shape_map, arg_dtype_map, arg_stype_map,
grad_req_types, param_names,
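
Two patterns recur in the hunks above: the range-for variable is bound as const auto& so each TensorRT-compatible subset is visited by reference instead of being copied, and the by-value symbol argument is forwarded with std::move (presumably the reason <utility> is now included) so it is not copied a second time inside the call chain. A compilable sketch of both, using stand-in types rather than the real nnvm::Symbol and executor classes:

// Sketch only; Symbol and the subset vector below are stand-ins, not MXNet types.
#include <iostream>
#include <string>
#include <utility>   // std::move
#include <vector>

struct Symbol {                  // stand-in for nnvm::Symbol
  std::string name;
};

void Init(Symbol symbol) {       // callee takes its argument by value
  std::cout << "initializing " << symbol.name << "\n";
}

void Bind(Symbol symbol) {
  // Bind already owns its by-value copy, so moving it into Init avoids a
  // second copy; `symbol` must not be read after this point.
  Init(std::move(symbol));
}

int main() {
  std::vector<std::vector<int>> trt_groups = {{1, 2, 3}, {4}, {5, 6}};

  // const auto& binds each group by reference instead of copying the vector.
  for (const auto& trt_group : trt_groups) {
    if (trt_group.size() > 1) {
      std::cout << "subgraph of " << trt_group.size() << " nodes\n";
    }
  }

  Bind(Symbol{"resnet"});
  return 0;
}
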
12 changes: 6 additions & 6 deletions src/operator/contrib/nnvm_to_onnx-inl.h
@@ -70,14 +70,14 @@ std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGra
void ConvertPlaceholder(
const std::string& node_name,
const std::unordered_map<std::string, TShape>& placeholder_shapes,
- GraphProto* const graph_proto);
+ GraphProto* graph_proto);

- void ConvertConstant(GraphProto* const graph_proto,
+ void ConvertConstant(GraphProto* graph_proto,
const std::string& node_name,
- std::unordered_map<std::string, NDArray>* const shared_buffer);
+ std::unordered_map<std::string, NDArray>* shared_buffer);

- void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map,
- GraphProto* const graph_proto,
+ void ConvertOutput(op::tensorrt::InferenceMap_t* trt_output_map,
+ GraphProto* graph_proto,
const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
const std::string& node_name,
const nnvm::Graph& g,
@@ -135,7 +135,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,

TRTParam ConvertNnvmGraphToOnnx(
const nnvm::Graph &g,
- std::unordered_map<std::string, NDArray> *const shared_buffer);
+ std::unordered_map<std::string, NDArray>* shared_buffer);

static const std::unordered_map<std::string, ConverterFunction> converter_map = {
{"Convolution", ConvertConvolution},
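
These header changes drop the top-level const from pointer parameters. Top-level const is not part of a function's signature and places no restriction on callers; it only constrains the parameter variable inside a definition, so in a declaration it is noise. A minimal illustration with hypothetical names (ConvertNode and the GraphProto stand-in below are not the actual MXNet or ONNX types):

#include <string>

struct GraphProto { int num_nodes = 0; };   // stand-in, not the ONNX protobuf class

// Declaration: `GraphProto* graph_proto` and `GraphProto* const graph_proto`
// declare the same function, so the simpler spelling is preferred here.
void ConvertNode(const std::string& node_name, GraphProto* graph_proto);

// Definition: the implementation could still mark the pointer const locally to
// forbid reseating it; that choice is invisible to callers.
void ConvertNode(const std::string& node_name, GraphProto* graph_proto) {
  (void)node_name;                          // unused in this sketch
  ++graph_proto->num_nodes;
}

int main() {
  GraphProto g;
  ConvertNode("conv0", &g);
  return g.num_nodes == 1 ? 0 : 1;
}
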
10 changes: 5 additions & 5 deletions src/operator/contrib/nnvm_to_onnx.cc
@@ -60,10 +60,10 @@ namespace nnvm_to_onnx {
op::TRTParam ConvertNnvmGraphToOnnx(
const nnvm::Graph& g,
std::unordered_map<std::string, NDArray>* const shared_buffer) {
- op::TRTParam trt_param;
- op::tensorrt::NameToIdx_t trt_input_map;
- op::tensorrt::InferenceMap_t trt_output_map;

+ op::TRTParam trt_param;
+ op::tensorrt::NameToIdx_t trt_input_map;
+ op::tensorrt::InferenceMap_t trt_output_map;
const nnvm::IndexedGraph& ig = g.indexed_graph();
const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
@@ -240,7 +240,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const kernel_shape = node_proto->add_attribute();
kernel_shape->set_name("kernel_shape");
kernel_shape->set_type(AttributeProto::INTS);
- for (int kval : kernel) {
+ for (dim_t kval : kernel) {
kernel_shape->add_ints(static_cast<int64>(kval));
}

@@ -256,7 +256,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const strides = node_proto->add_attribute();
strides->set_name("strides");
strides->set_type(AttributeProto::INTS);
- for (int kval : stride) {
+ for (dim_t kval : stride) {
strides->add_ints(static_cast<int64>(kval));
}

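
The last two hunks change the range-for element type from int to dim_t so it matches the element type of the kernel and stride shapes and avoids an implicit conversion on every element before the value is cast for add_ints. A self-contained sketch under assumed aliases (the dim_t and TShape definitions below are stand-ins for the mshadow/nnvm types, and the int64 cast mirrors the original code):

#include <cstdint>
#include <iostream>
#include <vector>

using dim_t = std::int64_t;            // stand-in for the shape element type
using TShape = std::vector<dim_t>;     // stand-in for nnvm::TShape

int main() {
  TShape kernel = {3, 3};
  TShape stride = {2, 2};

  // Iterating with dim_t matches the container's element type, so no value is
  // converted through int on the way into the attribute list.
  std::vector<std::int64_t> kernel_shape_ints;
  for (dim_t kval : kernel) {
    kernel_shape_ints.push_back(static_cast<std::int64_t>(kval));
  }

  std::vector<std::int64_t> stride_ints;
  for (dim_t kval : stride) {
    stride_ints.push_back(static_cast<std::int64_t>(kval));
  }

  std::cout << "kernel dims: " << kernel_shape_ints.size()
            << ", stride dims: " << stride_ints.size() << "\n";
  return 0;
}
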
