Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-703] Minor refactor of TensorRT code
Browse files Browse the repository at this point in the history
  • Loading branch information
KellenSunderland committed Jan 15, 2019
1 parent 9d42812 commit 8d2d15d
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/executor/onnx_to_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ nvinfer1::ICudaEngine* onnxToTrtCtx(
}

if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) {
int nerror = trt_parser->getNbErrors();
for ( int i=0; i < nerror; ++i ) {
size_t nerror = trt_parser->getNbErrors();
for ( size_t i=0; i < nerror; ++i ) {
nvonnxparser::IParserError const* error = trt_parser->getError(i);
if ( error->node() != -1 ) {
::ONNX_NAMESPACE::NodeProto const& node =
Expand Down
7 changes: 3 additions & 4 deletions src/executor/trt_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
}

auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer);
for (auto trt_group : trt_groups) {
for (const auto &trt_group : trt_groups) {
if (trt_group.size() > 1) {
g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer);
g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
Expand All @@ -142,7 +142,6 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol,
}
}


InitArguments(g.indexed_graph(), g.GetAttr<nnvm::ShapeVector>("shape"),
g.GetAttr<nnvm::DTypeVector>("dtype"),
g.GetAttr<StorageTypeVector>("storage_type"),
Expand Down Expand Up @@ -188,7 +187,7 @@ void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
const uint32_t eid = idx.entry_id(nid, 0);
const TShape& inferred_shape = inferred_shapes[eid];
const int inferred_dtype = inferred_dtypes[eid];
const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
const auto inferred_stype = static_cast<NDArrayStorageType>(inferred_stypes[eid]);
const std::string& arg_name = idx[nid].source->attrs.name;
// aux_states
if (mutable_nodes.count(nid)) {
Expand Down Expand Up @@ -427,7 +426,7 @@ Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray> *shared_buffer,
Executor *shared_exec) {
auto exec = new exec::TrtGraphExecutor();
exec->Init(symbol, default_ctx, group2ctx,
exec->Init(std::move(symbol), default_ctx, group2ctx,
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
arg_shape_map, arg_dtype_map, arg_stype_map,
grad_req_types, param_names,
Expand Down
14 changes: 7 additions & 7 deletions src/operator/contrib/nnvm_to_onnx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct ONNXParam : public dmlc::Parameter<ONNXParam> {
nnvm_to_onnx::InferenceMap_t output_map;
::onnx::ModelProto onnx_pb_graph;

ONNXParam() {}
ONNXParam() = default;

ONNXParam(const ::onnx::ModelProto& onnx_graph,
const nnvm_to_onnx::InferenceMap_t& input_map,
Expand Down Expand Up @@ -104,14 +104,14 @@ std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGra
void ConvertPlaceholder(
const std::string& node_name,
const std::unordered_map<std::string, TShape>& placeholder_shapes,
GraphProto* const graph_proto);
GraphProto* graph_proto);

void ConvertConstant(GraphProto* const graph_proto,
void ConvertConstant(GraphProto* graph_proto,
const std::string& node_name,
std::unordered_map<std::string, NDArray>* const shared_buffer);
std::unordered_map<std::string, NDArray>* shared_buffer);

void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map,
GraphProto* const graph_proto,
void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* trt_output_map,
GraphProto* graph_proto,
const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
const std::string& node_name,
const nnvm::Graph& g,
Expand Down Expand Up @@ -169,7 +169,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,

ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph &g,
std::unordered_map<std::string, NDArray> *const shared_buffer);
std::unordered_map<std::string, NDArray>* shared_buffer);

static const std::unordered_map<std::string, ConverterFunction> converter_map = {
{"Convolution", ConvertConvolution},
Expand Down
17 changes: 9 additions & 8 deletions src/operator/contrib/nnvm_to_onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ namespace nnvm_to_onnx {
op::ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph& g,
std::unordered_map<std::string, NDArray>* const shared_buffer) {
op::ONNXParam onnx_param;
op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
op::nnvm_to_onnx::InferenceMap_t onnx_output_map;

op::ONNXParam onnx_param;
op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
op::nnvm_to_onnx::InferenceMap_t onnx_output_map;

const nnvm::IndexedGraph& ig = g.indexed_graph();
const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
Expand Down Expand Up @@ -242,23 +243,23 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const kernel_shape = node_proto->add_attribute();
kernel_shape->set_name("kernel_shape");
kernel_shape->set_type(AttributeProto::INTS);
for (int kval : kernel) {
for (dim_t kval : kernel) {
kernel_shape->add_ints(static_cast<int64>(kval));
}

// pads
AttributeProto* const pads = node_proto->add_attribute();
pads->set_name("pads");
pads->set_type(AttributeProto::INTS);
for (int kval : pad) {
for (dim_t kval : pad) {
pads->add_ints(static_cast<int64>(kval));
}

// strides
AttributeProto* const strides = node_proto->add_attribute();
strides->set_name("strides");
strides->set_type(AttributeProto::INTS);
for (int kval : stride) {
for (dim_t kval : stride) {
strides->add_ints(static_cast<int64>(kval));
}

Expand Down Expand Up @@ -469,7 +470,7 @@ void ConvertConstant(
const NDArray nd = shared_buffer->find(node_name)->second;
const TBlob& blob = nd.data();
const TShape shape = blob.shape_;
const int32_t size = shape.Size();
const size_t size = shape.Size();

std::shared_ptr<float> shared_data_ptr(new float[size]);
float* const data_ptr = shared_data_ptr.get();
Expand All @@ -485,7 +486,7 @@ void ConvertConstant(
tensor_proto->add_dims(static_cast<int64>(dim));
}

for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
for (size_t blob_idx = 0; blob_idx < size; ++blob_idx) {
tensor_proto->add_float_data(data_ptr[blob_idx]);
}
}
Expand Down

0 comments on commit 8d2d15d

Please sign in to comment.