Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-703] Minor refactor of TensorRT code
Browse files Browse the repository at this point in the history
  • Loading branch information
KellenSunderland committed Jan 16, 2019
1 parent cc15d9a commit 247c322
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/executor/onnx_to_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ nvinfer1::ICudaEngine* onnxToTrtCtx(
}

if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
int nerror = trt_parser->getNbErrors();
for ( int i=0; i < nerror; ++i ) {
size_t nerror = trt_parser->getNbErrors();
for ( size_t i=0; i < nerror; ++i ) {
nvonnxparser::IParserError const* error = trt_parser->getError(i);
if ( error->node() != -1 ) {
::ONNX_NAMESPACE::NodeProto const& node =
Expand Down
7 changes: 3 additions & 4 deletions src/executor/trt_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
}

auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
for (auto trt_group : trt_groups) {
for (const auto &trt_group : trt_groups) {
if (trt_group.size() > 1) {
g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
Expand All @@ -142,7 +142,6 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
}
}


InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
g.GetAttr<nnvm::DTypeVector>("dtype"),
g.GetAttr<StorageTypeVector>("storage_type"),
Expand Down Expand Up @@ -188,7 +187,7 @@ void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
const uint32_t eid = idx.entry_id(nid, 0);
const TShape& inferred_shape = inferred_shapes[eid];
const int inferred_dtype = inferred_dtypes[eid];
const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
const auto inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
const std::string& arg_name = idx[nid].source->attrs.name;
// aux_states
if (mutable_nodes.count(nid)) {
Expand Down Expand Up @@ -427,7 +426,7 @@ Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray> *shared_buffer,
Executor *shared_exec) {
auto exec = new exec::TrtGraphExecutor();
exec->Init(symbol, default_ctx, group2ctx,
exec->Init(std::move(symbol), default_ctx, group2ctx,
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
arg_shape_map, arg_dtype_map, arg_stype_map,
grad_req_types, param_names,
Expand Down
14 changes: 7 additions & 7 deletions src/operator/contrib/nnvm_to_onnx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct ONNXParam : public dmlc::Parameter<ONNXParam> {
nnvm_to_onnx::InferenceMap_t output_map;
::onnx::ModelProto onnx_pb_graph;

ONNXParam() {}
ONNXParam() = default;

ONNXParam(const ::onnx::ModelProto& onnx_graph,
const nnvm_to_onnx::InferenceMap_t& input_map,
Expand Down Expand Up @@ -104,14 +104,14 @@ std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGra
void ConvertPlaceholder(
const std::string& node_name,
const std::unordered_map<std::string, TShape>& placeholder_shapes,
GraphProto* const graph_proto);
GraphProto* graph_proto);

void ConvertConstant(GraphProto* const graph_proto,
void ConvertConstant(GraphProto* graph_proto,
const std::string& node_name,
std::unordered_map<std::string, NDArray>* const shared_buffer);
std::unordered_map<std::string, NDArray>* shared_buffer);

void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map,
GraphProto* const graph_proto,
void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* trt_output_map,
GraphProto* graph_proto,
const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
const std::string& node_name,
const nnvm::Graph& g,
Expand Down Expand Up @@ -169,7 +169,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,

ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph &g,
std::unordered_map<std::string, NDArray> *const shared_buffer);
std::unordered_map<std::string, NDArray>* shared_buffer);

static const std::unordered_map<std::string, ConverterFunction> converter_map = {
{"Convolution", ConvertConvolution},
Expand Down
4 changes: 2 additions & 2 deletions src/operator/contrib/nnvm_to_onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const kernel_shape = node_proto->add_attribute();
kernel_shape->set_name("kernel_shape");
kernel_shape->set_type(AttributeProto::INTS);
for (int kval : kernel) {
for (dim_t kval : kernel) {
kernel_shape->add_ints(static_cast<int64>(kval));
}

Expand All @@ -283,7 +283,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const strides = node_proto->add_attribute();
strides->set_name("strides");
strides->set_type(AttributeProto::INTS);
for (int kval : stride) {
for (dim_t kval : stride) {
strides->add_ints(static_cast<int64>(kval));
}

Expand Down

0 comments on commit 247c322

Please sign in to comment.