Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-703] Minor refactor of TensorRT code #13311

Merged
merged 1 commit into from
Jan 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/executor/onnx_to_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ nvinfer1::ICudaEngine* onnxToTrtCtx(
}

if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
int nerror = trt_parser->getNbErrors();
for ( int i=0; i < nerror; ++i ) {
size_t nerror = trt_parser->getNbErrors();
for ( size_t i=0; i < nerror; ++i ) {
KellenSunderland marked this conversation as resolved.
Show resolved Hide resolved
nvonnxparser::IParserError const* error = trt_parser->getError(i);
if ( error->node() != -1 ) {
::ONNX_NAMESPACE::NodeProto const& node =
Expand Down
7 changes: 3 additions & 4 deletions src/executor/trt_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
}

auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
for (auto trt_group : trt_groups) {
for (const auto &trt_group : trt_groups) {
if (trt_group.size() > 1) {
g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
Expand All @@ -142,7 +142,6 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
}
}


InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
g.GetAttr<nnvm::DTypeVector>("dtype"),
g.GetAttr<StorageTypeVector>("storage_type"),
Expand Down Expand Up @@ -188,7 +187,7 @@ void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
const uint32_t eid = idx.entry_id(nid, 0);
const TShape& inferred_shape = inferred_shapes[eid];
const int inferred_dtype = inferred_dtypes[eid];
const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
const auto inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
KellenSunderland marked this conversation as resolved.
Show resolved Hide resolved
const std::string& arg_name = idx[nid].source->attrs.name;
// aux_states
if (mutable_nodes.count(nid)) {
Expand Down Expand Up @@ -427,7 +426,7 @@ Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray> *shared_buffer,
Executor *shared_exec) {
auto exec = new exec::TrtGraphExecutor();
exec->Init(symbol, default_ctx, group2ctx,
exec->Init(std::move(symbol), default_ctx, group2ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is Init accepting an rvalue reference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good question that I wanted to dive into a little. My understanding is that a by-value parameter like Init's can bind to either an lvalue or an rvalue; you don't have to create a separate rvalue-reference overload to determine which one it receives. Here's an example:

https://gist.github.com/KellenSunderland/d431119580a116410c672b9535d85170
Which is using code borrowed from:
https://www.chromium.org/rvalue-references?tmpl=%2Fsystem%2Fapp%2Ftemplates%2Fprint%2F&showPrintDialog=1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the question is whether the move is actually doing anything — I don't know if there's an Init signature accepting an rvalue reference to a symbol. That was the original question.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was also curious about this. The linked example shows a similar pattern (passing a moved value to a by-value parameter), and it does save a copy.

in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
arg_shape_map, arg_dtype_map, arg_stype_map,
grad_req_types, param_names,
Expand Down
14 changes: 7 additions & 7 deletions src/operator/contrib/nnvm_to_onnx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct ONNXParam : public dmlc::Parameter<ONNXParam> {
nnvm_to_onnx::InferenceMap_t output_map;
::onnx::ModelProto onnx_pb_graph;

ONNXParam() {}
ONNXParam() = default;

ONNXParam(const ::onnx::ModelProto& onnx_graph,
const nnvm_to_onnx::InferenceMap_t& input_map,
Expand Down Expand Up @@ -104,14 +104,14 @@ std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGra
void ConvertPlaceholder(
const std::string& node_name,
const std::unordered_map<std::string, TShape>& placeholder_shapes,
GraphProto* const graph_proto);
GraphProto* graph_proto);

void ConvertConstant(GraphProto* const graph_proto,
void ConvertConstant(GraphProto* graph_proto,
const std::string& node_name,
std::unordered_map<std::string, NDArray>* const shared_buffer);
std::unordered_map<std::string, NDArray>* shared_buffer);

void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map,
GraphProto* const graph_proto,
void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* trt_output_map,
GraphProto* graph_proto,
const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
const std::string& node_name,
const nnvm::Graph& g,
Expand Down Expand Up @@ -169,7 +169,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,

ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph &g,
std::unordered_map<std::string, NDArray> *const shared_buffer);
std::unordered_map<std::string, NDArray>* shared_buffer);

static const std::unordered_map<std::string, ConverterFunction> converter_map = {
{"Convolution", ConvertConvolution},
Expand Down
4 changes: 2 additions & 2 deletions src/operator/contrib/nnvm_to_onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const kernel_shape = node_proto->add_attribute();
kernel_shape->set_name("kernel_shape");
kernel_shape->set_type(AttributeProto::INTS);
for (int kval : kernel) {
for (dim_t kval : kernel) {
kernel_shape->add_ints(static_cast<int64>(kval));
}

Expand All @@ -283,7 +283,7 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const strides = node_proto->add_attribute();
strides->set_name("strides");
strides->set_type(AttributeProto::INTS);
for (int kval : stride) {
for (dim_t kval : stride) {
strides->add_ints(static_cast<int64>(kval));
}

Expand Down