
[MXNET-1252][1 of 2] Decouple NNVM to ONNX from NNVM to TensorRT conversion
haohuanw committed Dec 17, 2018
1 parent aa240cb commit 29f7a53
Showing 6 changed files with 76 additions and 76 deletions.
8 changes: 4 additions & 4 deletions src/executor/tensorrt_pass.cc
@@ -324,10 +324,10 @@ nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
std::unordered_map<std::string, NDArray>* const params_map) {
auto p = nnvm::Node::Create();
p->attrs.op = nnvm::Op::Get("_trt_op");
op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map;
p->attrs.dict["serialized_input_map"] = trt_param.serialized_input_map;
p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph;
op::ONNXParam onnx_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map);
p->attrs.dict["serialized_output_map"] = onnx_param.serialized_output_map;
p->attrs.dict["serialized_input_map"] = onnx_param.serialized_input_map;
p->attrs.dict["serialized_onnx_graph"] = onnx_param.serialized_onnx_graph;
if (p->op()->attr_parser != nullptr) {
p->op()->attr_parser(&(p->attrs));
}
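The only change in this pass is the rename from op::TRTParam to op::ONNXParam: the generated _trt_op node still carries the three serialized strings in its attribute dict, and the op's attr_parser turns them back into structured fields. A minimal sketch of that round trip, assuming common::Deserialize mirrors the common::Serialize calls used elsewhere in this commit (the deserialization calls are an assumption, not part of this diff):

// Producer side (this pass): attach the serialized blobs to the generated _trt_op node.
p->attrs.dict["serialized_onnx_graph"] = onnx_param.serialized_onnx_graph;
// Consumer side (sketch of the attr parser): recover the structured maps.
ONNXParam param;
param.Init(p->attrs.dict);  // reads the three serialized_* strings
common::Deserialize(&param.input_map, param.serialized_input_map);    // assumed helper
common::Deserialize(&param.output_map, param.serialized_output_map);  // assumed helper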
42 changes: 38 additions & 4 deletions src/operator/contrib/nnvm_to_onnx-inl.h
@@ -37,7 +36,6 @@
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>

#include <NvInfer.h>
#include <onnx/onnx.pb.h>

#include <algorithm>
@@ -49,13 +48,48 @@
#include <utility>
#include <string>

#include "./tensorrt-inl.h"
#include "../operator_common.h"
#include "../../common/utils.h"
#include "../../common/serialization.h"

namespace mxnet {
namespace op {

namespace nnvm_to_onnx {
enum class TypeIO { Inputs = 0, Outputs = 1 };
using NameToIdx_t = std::map<std::string, int32_t>;
using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
} // namespace nnvm_to_onnx

struct ONNXParam : public dmlc::Parameter<ONNXParam> {
std::string serialized_onnx_graph;
std::string serialized_input_map;
std::string serialized_output_map;
nnvm_to_onnx::NameToIdx_t input_map;
nnvm_to_onnx::InferenceMap_t output_map;
::onnx::ModelProto onnx_pb_graph;

ONNXParam() {}

ONNXParam(const ::onnx::ModelProto& onnx_graph,
const nnvm_to_onnx::InferenceMap_t& input_map,
const nnvm_to_onnx::NameToIdx_t& output_map) {
common::Serialize(input_map, &serialized_input_map);
common::Serialize(output_map, &serialized_output_map);
onnx_graph.SerializeToString(&serialized_onnx_graph);
}

DMLC_DECLARE_PARAMETER(ONNXParam) {
DMLC_DECLARE_FIELD(serialized_onnx_graph)
.describe("Serialized ONNX graph");
DMLC_DECLARE_FIELD(serialized_input_map)
.describe("Map from inputs to topological order as input.");
DMLC_DECLARE_FIELD(serialized_output_map)
.describe("Map from outputs to order in g.outputs.");
}
};

namespace nnvm_to_onnx {

using namespace nnvm;
@@ -76,7 +110,7 @@ void ConvertConstant(GraphProto* const graph_proto,
const std::string& node_name,
std::unordered_map<std::string, NDArray>* const shared_buffer);

void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map,
void ConvertOutput(op::nnvm_to_onnx::InferenceMap_t* const trt_output_map,
GraphProto* const graph_proto,
const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
const std::string& node_name,
@@ -133,7 +167,7 @@ void ConvertElementwiseAdd(NodeProto *node_proto,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);

TRTParam ConvertNnvmGraphToOnnx(
ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph &g,
std::unordered_map<std::string, NDArray> *const shared_buffer);

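With this move, nnvm_to_onnx-inl.h no longer includes NvInfer.h or tensorrt-inl.h: the IO bookkeeping types (TypeIO, NameToIdx_t, InferenceTuple_t, InferenceMap_t) and ONNXParam are now TensorRT-agnostic. An illustrative entry of InferenceMap_t, following the field order ConvertOutput uses below (names and values are hypothetical):

// Illustrative only: one output_map entry.
// InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>
//                    {position in g.outputs, output shape, storage type, dtype}
nnvm_to_onnx::InferenceMap_t output_map;
nnvm_to_onnx::InferenceTuple_t entry{0,                        // index into g.outputs
                                     mxnet::TShape({32, 10}),  // inferred output shape
                                     mxnet::kDefaultStorage,   // storage type
                                     mshadow::kFloat32};       // dtype
output_map.emplace("softmax_output", entry);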
34 changes: 18 additions & 16 deletions src/operator/contrib/nnvm_to_onnx.cc
@@ -47,22 +47,24 @@
#include "../../operator/nn/fully_connected-inl.h"
#include "../../operator/nn/pooling-inl.h"
#include "../../operator/softmax_output-inl.h"
#include "./tensorrt-inl.h"

#if MXNET_USE_TENSORRT_ONNX_CHECKER
#include <onnx/checker.h>
#endif // MXNET_USE_TENSORRT_ONNX_CHECKER

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(ONNXParam);

namespace nnvm_to_onnx {

op::TRTParam ConvertNnvmGraphToOnnx(
op::ONNXParam ConvertNnvmGraphToOnnx(
const nnvm::Graph& g,
std::unordered_map<std::string, NDArray>* const shared_buffer) {
op::TRTParam trt_param;
op::tensorrt::NameToIdx_t trt_input_map;
op::tensorrt::InferenceMap_t trt_output_map;
op::ONNXParam onnx_param;
op::nnvm_to_onnx::NameToIdx_t onnx_input_map;
op::nnvm_to_onnx::InferenceMap_t onnx_output_map;

const nnvm::IndexedGraph& ig = g.indexed_graph();
const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
@@ -105,7 +107,7 @@ op::TRTParam ConvertNnvmGraphToOnnx(
current_input++;
continue;
}
trt_input_map.emplace(node_name, current_input++);
onnx_input_map.emplace(node_name, current_input++);
ConvertPlaceholder(node_name, placeholder_shapes, graph_proto);
} else {
// If it's not a placeholder, then by exclusion it's a constant.
@@ -140,23 +142,23 @@ op::TRTParam ConvertNnvmGraphToOnnx(
auto out_iter = output_lookup.find(node_name);
// We found an output
if (out_iter != output_lookup.end()) {
ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g,
ConvertOutput(&onnx_output_map, graph_proto, out_iter, node_name, g,
storage_types, dtypes);
} // output found
} // conversion function exists
} // loop over i from 0 to num_nodes

model_proto.SerializeToString(&trt_param.serialized_onnx_graph);
common::Serialize<op::tensorrt::NameToIdx_t>(trt_input_map,
&trt_param.serialized_input_map);
common::Serialize<op::tensorrt::InferenceMap_t>(trt_output_map,
&trt_param.serialized_output_map);
model_proto.SerializeToString(&onnx_param.serialized_onnx_graph);
common::Serialize<op::nnvm_to_onnx::NameToIdx_t>(onnx_input_map,
&onnx_param.serialized_input_map);
common::Serialize<op::nnvm_to_onnx::InferenceMap_t>(onnx_output_map,
&onnx_param.serialized_output_map);

#if MXNET_USE_TENSORRT_ONNX_CHECKER
onnx::checker::check_model(model_proto);
#endif // MXNET_USE_TENSORRT_ONNX_CHECKER

return trt_param;
return onnx_param;
}

void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
@@ -489,7 +491,7 @@ void ConvertConstant(
}

void ConvertOutput(
op::tensorrt::InferenceMap_t* const trt_output_map,
op::nnvm_to_onnx::InferenceMap_t* const output_map,
GraphProto* const graph_proto,
const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
const std::string& node_name, const nnvm::Graph& g,
@@ -501,10 +503,10 @@ void ConvertOutput(
int dtype = dtypes[out_idx];

// This should work with fp16 as well
op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
op::nnvm_to_onnx::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
dtype};

trt_output_map->emplace(node_name, out_tuple);
output_map->emplace(node_name, out_tuple);

auto graph_out = graph_proto->add_output();
auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
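The conversion entry point keeps its signature apart from the return type: callers such as the tensorrt_pass.cc hunk above now receive an ONNXParam whose serialized fields can be parsed back into an onnx::ModelProto. A hedged usage sketch (the check message is illustrative):

// Convert the partitioned subgraph, then re-parse the serialized model,
// mirroring the check performed in TRTCreateState below.
op::ONNXParam onnx_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, &shared_buffer);
::onnx::ModelProto model_proto;
bool success = model_proto.ParseFromString(onnx_param.serialized_onnx_graph);
CHECK(success) << "Problems parsing serialized ONNX model.";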
38 changes: 2 additions & 36 deletions src/operator/contrib/tensorrt-inl.h
@@ -38,7 +38,6 @@
#include <nnvm/pass_functions.h>

#include <NvInfer.h>
#include <onnx/onnx.pb.h>

#include <algorithm>
#include <iostream>
@@ -49,6 +48,7 @@
#include <utility>
#include <string>

#include "nnvm_to_onnx-inl.h"
#include "../operator_common.h"
#include "../../common/utils.h"
#include "../../common/serialization.h"
@@ -60,49 +60,15 @@ namespace mxnet {
namespace op {

using namespace nnvm;
using namespace ::onnx;
using int64 = ::google::protobuf::int64;

namespace tensorrt {
enum class TypeIO { Inputs = 0, Outputs = 1 };
using NameToIdx_t = std::map<std::string, int32_t>;
using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
} // namespace tensorrt

using trt_name_to_idx = std::map<std::string, uint32_t>;

struct TRTParam : public dmlc::Parameter<TRTParam> {
std::string serialized_onnx_graph;
std::string serialized_input_map;
std::string serialized_output_map;
tensorrt::NameToIdx_t input_map;
tensorrt::InferenceMap_t output_map;
::onnx::ModelProto onnx_pb_graph;

TRTParam() {}

TRTParam(const ::onnx::ModelProto& onnx_graph,
const tensorrt::InferenceMap_t& input_map,
const tensorrt::NameToIdx_t& output_map) {
common::Serialize(input_map, &serialized_input_map);
common::Serialize(output_map, &serialized_output_map);
onnx_graph.SerializeToString(&serialized_onnx_graph);
}

DMLC_DECLARE_PARAMETER(TRTParam) {
DMLC_DECLARE_FIELD(serialized_onnx_graph)
.describe("Serialized ONNX graph");
DMLC_DECLARE_FIELD(serialized_input_map)
.describe("Map from inputs to topological order as input.");
DMLC_DECLARE_FIELD(serialized_output_map)
.describe("Map from outputs to order in g.outputs.");
}
};

struct TRTEngineParam {
nvinfer1::IExecutionContext* trt_executor;
std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
std::vector<std::pair<uint32_t, nnvm_to_onnx::TypeIO> > binding_map;
};

} // namespace op
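With the parameter struct and IO aliases moved out, tensorrt-inl.h keeps only engine-side state; TRTEngineParam pairs each TensorRT binding with an argument position and a TypeIO tag that now lives in the nnvm_to_onnx namespace. A minimal sketch of what binding_map holds, assuming the population logic shown in GetPtrMapping below:

TRTEngineParam param;
// One entry per engine binding: (position among the op's inputs or outputs, direction tag).
param.binding_map.emplace_back(0u, nnvm_to_onnx::TypeIO::Inputs);   // binding 0 -> inputs[0]
param.binding_map.emplace_back(0u, nnvm_to_onnx::TypeIO::Outputs);  // binding 1 -> outputs[0]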
28 changes: 13 additions & 15 deletions src/operator/contrib/tensorrt.cc
@@ -44,20 +44,18 @@
namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(TRTParam);

OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
tensorrt::NameToIdx_t input_map,
tensorrt::NameToIdx_t output_map) {
nnvm_to_onnx::NameToIdx_t input_map,
nnvm_to_onnx::NameToIdx_t output_map) {
TRTEngineParam param;
for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
const std::string& binding_name = trt_engine->getBindingName(b);
if (trt_engine->bindingIsInput(b)) {
param.binding_map.emplace_back(input_map[binding_name],
tensorrt::TypeIO::Inputs);
nnvm_to_onnx::TypeIO::Inputs);
} else {
param.binding_map.emplace_back(output_map[binding_name],
tensorrt::TypeIO::Outputs);
nnvm_to_onnx::TypeIO::Outputs);
}
}
param.trt_executor = trt_engine->createExecutionContext();
@@ -67,7 +65,7 @@ OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
const std::vector<TShape>& /*ishape*/,
const std::vector<int>& /*itype*/) {
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);

::onnx::ModelProto model_proto;
bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
@@ -82,15 +80,15 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
node_param.serialized_onnx_graph, batch_size, 1 << 30);

tensorrt::NameToIdx_t output_map;
nnvm_to_onnx::NameToIdx_t output_map;
for (auto& el : node_param.output_map) {
output_map[el.first] = std::get<0>(el.second);
}
return GetPtrMapping(trt_engine, node_param.input_map, output_map);
}

void TRTParamParser(nnvm::NodeAttrs* attrs) {
TRTParam param_;
ONNXParam param_;

try {
param_.Init(attrs->dict);
Expand All @@ -114,7 +112,7 @@ void TRTParamParser(nnvm::NodeAttrs* attrs) {

inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* /*in_shape*/,
std::vector<TShape>* out_shape) {
const auto &node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto &node_param = nnvm::get<ONNXParam>(attrs.parsed);
for (auto& el : node_param.output_map) {
(*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
}
@@ -131,7 +129,7 @@ inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask

inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
std::vector<int>* out_dtype) {
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
for (auto& el : node_param.output_map) {
(*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
}
@@ -140,7 +138,7 @@ inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,

inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
std::vector<std::string> output;
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
output.resize(node_param.input_map.size());
for (auto& el : node_param.input_map) {
output[el.second] = el.first;
@@ -150,7 +148,7 @@ inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {

inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
std::vector<std::string> output;
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
output.resize(node_param.output_map.size());
for (auto& el : node_param.output_map) {
output[std::get<0>(el.second)] = el.first;
@@ -162,11 +160,11 @@ NNVM_REGISTER_OP(_trt_op)
.describe(R"code(TRT operation (one engine)
)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
return node_param.input_map.size();
})
.set_num_outputs([](const NodeAttrs& attrs) {
const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
const auto& node_param = nnvm::get<ONNXParam>(attrs.parsed);
return node_param.output_map.size();
})
.set_attr_parser(TRTParamParser)
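Every nnvm::get<...>(attrs.parsed) site in this file now retrieves ONNXParam, and DMLC_REGISTER_PARAMETER moves to nnvm_to_onnx.cc so the TensorRT translation unit no longer owns the parameter type. A hedged sketch of building the engine state from the renamed maps, condensed from TRTCreateState above (the names and the single-input, single-output layout are hypothetical):

// Hypothetical maps for a single-input, single-output engine.
nnvm_to_onnx::NameToIdx_t input_map{{"data", 0}};
nnvm_to_onnx::NameToIdx_t output_map{{"softmax_output", 0}};
OpStatePtr state = GetPtrMapping(trt_engine, input_map, output_map);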
2 changes: 1 addition & 1 deletion src/operator/contrib/tensorrt.cu
@@ -52,7 +52,7 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
std::vector<void*> bindings;
bindings.reserve(param.binding_map.size());
for (auto& p : param.binding_map) {
if (p.second == tensorrt::TypeIO::Inputs) {
if (p.second == nnvm_to_onnx::TypeIO::Inputs) {
bindings.emplace_back(inputs[p.first].dptr_);
} else {
bindings.emplace_back(outputs[p.first].dptr_);
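The compute kernel is unchanged apart from the namespace of the TypeIO tag: each TensorRT binding is routed to the matching input or output NDArray pointer. A condensed restatement of the dispatch loop above:

for (auto& p : param.binding_map) {
  void* dptr = (p.second == nnvm_to_onnx::TypeIO::Inputs)
                   ? inputs[p.first].dptr_    // engine input reads from the op's inputs
                   : outputs[p.first].dptr_;  // engine output writes into the op's outputs
  bindings.emplace_back(dptr);
}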
