enable TensorRT integration with cpp api
haohuw committed Jun 23, 2019
1 parent f44f6cf commit 815e2b9
Showing 6 changed files with 235 additions and 12 deletions.
1 change: 1 addition & 0 deletions cpp-package/example/inference/README.md
@@ -44,6 +44,7 @@ inception_inference --symbol <model symbol file in json format>
[--input_shape <dimensions of input image, e.g. "3 224 224">]
[--mean <file containing mean image for normalizing the input image>]
[--gpu] Specify this option if workflow needs to be run in gpu context
[--enableTRT <Specify this option if workflow needs to be run with TensorRT>]
```
The model JSON, params, and synset files are required to run this example. A sample command line is as follows:

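For illustration, a hypothetical invocation that enables TensorRT might look like the following; the file names and paths are placeholders, not files shipped with this example.

```
./inception_inference --symbol "./model/Inception-BN-symbol.json" --params "./model/Inception-BN-0126.params" --synset "./model/synset.txt" --mean "./model/mean_224.nd" --image "./model/dog.jpg" --enableTRT
```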
79 changes: 67 additions & 12 deletions cpp-package/example/inference/inception_inference.cpp
@@ -54,6 +54,7 @@ class Predictor {
const std::string& model_params_file,
const Shape& input_shape,
bool gpu_context_type = false,
bool enable_tensorrt = false,
const std::string& synset_file = "",
const std::string& mean_image_file = "");
void PredictImage(const std::string& image_file);
@@ -62,6 +63,13 @@ class Predictor {
private:
void LoadModel(const std::string& model_json_file);
void LoadParameters(const std::string& model_parameters_file);
void SplitParamMap(const std::map<std::string, NDArray>& paramMap,
std::map<std::string, NDArray>& argParamInTargetContext,
std::map<std::string, NDArray>& auxParamInTargetContext,
Context targetContext);
void ConvertParamMapToTargetContext(const std::map<std::string, NDArray>& paramMap,
std::map<std::string, NDArray>& paramMapInTargetContext,
Context targetContext);
void LoadSynset(const std::string& synset_file);
NDArray LoadInputImage(const std::string& image_file);
void LoadMeanImageData();
@@ -75,6 +83,7 @@ class Predictor {
std::map<std::string, NDArray> args_map;
std::map<std::string, NDArray> aux_map;
std::vector<std::string> output_labels;
bool enable_tensorrt;
Symbol net;
Executor *executor;
Shape input_shape;
@@ -104,13 +113,16 @@ Predictor::Predictor(const std::string& model_json_file,
const std::string& model_params_file,
const Shape& input_shape,
bool gpu_context_type,
bool enable_tensorrt,
const std::string& synset_file,
const std::string& mean_image_file):
input_shape(input_shape),
mean_image_file(mean_image_file) {
if (gpu_context_type) {
global_ctx = Context::gpu();
}
this->enable_tensorrt = enable_tensorrt;

// Load the model
LoadModel(model_json_file);

@@ -135,8 +147,12 @@ Predictor::Predictor(const std::string& model_json_file,

// Create an executor after binding the model to input parameters.
args_map["data"] = NDArray(input_shape, global_ctx, false);
executor = net.SimpleBind(global_ctx, args_map, std::map<std::string, NDArray>(),
std::map<std::string, OpReqType>(), aux_map);
try {
executor = net.SimpleBind(global_ctx, args_map, std::map<std::string, NDArray>(),
std::map<std::string, OpReqType>(), aux_map);
} catch(const std::exception& e) {
LG << "SimpleBind error " << MXGetLastError();
}
}

/*
@@ -149,6 +165,9 @@ void Predictor::LoadModel(const std::string& model_json_file) {
}
LG << "Loading the model from " << model_json_file << std::endl;
net = Symbol::Load(model_json_file);
if (enable_tensorrt) {
net = net.GetBackendSymbol("TensorRT");
}
}


@@ -163,20 +182,49 @@ void Predictor::LoadParameters(const std::string& model_parameters_file) {
LG << "Loading the model parameters from " << model_parameters_file << std::endl;
std::map<std::string, NDArray> parameters;
NDArray::Load(model_parameters_file, 0, &parameters);
for (const auto &k : parameters) {
if (k.first.substr(0, 4) == "aux:") {
auto name = k.first.substr(4, k.first.size() - 4);
aux_map[name] = k.second.Copy(global_ctx);
}
if (k.first.substr(0, 4) == "arg:") {
auto name = k.first.substr(4, k.first.size() - 4);
args_map[name] = k.second.Copy(global_ctx);
}
if (enable_tensorrt) {
std::map<std::string, NDArray> intermediate_args_map;
std::map<std::string, NDArray> intermediate_aux_map;
SplitParamMap(parameters, intermediate_args_map, intermediate_aux_map, Context::cpu());
contrib::InitTensorRTParams(net, intermediate_args_map, intermediate_aux_map);
ConvertParamMapToTargetContext(intermediate_args_map, args_map, global_ctx);
ConvertParamMapToTargetContext(intermediate_aux_map, aux_map, global_ctx);
} else {
SplitParamMap(parameters, args_map, aux_map, global_ctx);
}
/* WaitAll is needed when we copy data between the GPU and main memory */
NDArray::WaitAll();
}

/*
* The following function splits the loaded param map into arg params
* and aux params, copied into the target context.
*/
void Predictor::SplitParamMap(const std::map<std::string, NDArray> &paramMap,
std::map<std::string, NDArray> &argParamInTargetContext,
std::map<std::string, NDArray> &auxParamInTargetContext,
Context targetContext) {
for (const auto& pair : paramMap) {
std::string type = pair.first.substr(0, 4);
std::string name = pair.first.substr(4);
if (type == "arg:") {
argParamInTargetContext[name] = pair.second.Copy(targetContext);
} else if (type == "aux:") {
auxParamInTargetContext[name] = pair.second.Copy(targetContext);
}
}
}

/*
* The following function copies the param map into the target context
*/
void Predictor::ConvertParamMapToTargetContext(const std::map<std::string, NDArray> &paramMap,
std::map<std::string, NDArray> &paramMapInTargetContext,
Context targetContext) {
for (const auto& pair : paramMap) {
paramMapInTargetContext[pair.first] = pair.second.Copy(targetContext);
}
}

/*
* The following function loads the synset file.
@@ -359,6 +407,8 @@ void printUsage() {
<< "[--mean <file containing mean image for normalizing the input image>] "
<< std::endl
<< "[--gpu <Specify this option if workflow needs to be run in gpu context>]"
<< std::endl
<< "[--enableTRT <Specify this option if workflow needs to be run with TensorRT>]"
<< std::endl;
}

@@ -369,6 +419,7 @@ int main(int argc, char** argv) {
std::string mean_image = "";
std::string input_image = "";
bool gpu_context_type = false;
bool enable_tensorrt = false;

std::string input_shape = "3 224 224";
int index = 1;
@@ -393,6 +444,9 @@ int main(int argc, char** argv) {
input_shape = (index < argc ? argv[index]:input_shape);
} else if (strcmp("--gpu", argv[index]) == 0) {
gpu_context_type = true;
} else if (strcmp("--enableTRT", argv[index]) == 0) {
gpu_context_type = true;  // TensorRT needs a GPU to run
enable_tensorrt = true;
} else if (strcmp("--help", argv[index]) == 0) {
printUsage();
return 0;
@@ -425,7 +479,8 @@ int main(int argc, char** argv) {

try {
// Initialize the predictor object
Predictor predict(model_file_json, model_file_params, input_data_shape, gpu_context_type,
Predictor predict(model_file_json, model_file_params, input_data_shape,
gpu_context_type, enable_tensorrt,
synset_file, mean_image);

// Run the forward pass to predict the image.
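Putting the pieces of this file together, the TensorRT path reduces to a short sequence of calls. The sketch below is illustrative only and is not part of the commit; the `resnet50-*` file names, the `data` input name, and the 1x3x224x224 shape are assumptions.

```cpp
// Minimal sketch of the TensorRT-enabled flow implemented by Predictor above.
// File names, input name, and input shape are placeholders.
#include <map>
#include <string>
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

int main() {
  Context ctx = Context::gpu();  // TensorRT needs a GPU context

  // Partition the graph so TensorRT-compatible subgraphs become TRT nodes.
  Symbol net = Symbol::Load("resnet50-symbol.json").GetBackendSymbol("TensorRT");

  // Load the checkpoint and split it into arg/aux maps on the CPU first.
  std::map<std::string, NDArray> params, args, aux;
  NDArray::Load("resnet50-0000.params", 0, &params);
  for (const auto& kv : params) {
    const std::string type = kv.first.substr(0, 4);
    const std::string name = kv.first.substr(4);
    if (type == "arg:") args[name] = kv.second.Copy(Context::cpu());
    else if (type == "aux:") aux[name] = kv.second.Copy(Context::cpu());
  }

  // Hand the TRT subgraphs their weights; the consumed entries are erased.
  contrib::InitTensorRTParams(net, args, aux);

  // Copy whatever is left to the GPU, add the input placeholder, and bind.
  std::map<std::string, NDArray> gpu_args, gpu_aux;
  for (const auto& kv : args) gpu_args[kv.first] = kv.second.Copy(ctx);
  for (const auto& kv : aux) gpu_aux[kv.first] = kv.second.Copy(ctx);
  gpu_args["data"] = NDArray(Shape(1, 3, 224, 224), ctx, false);
  NDArray::WaitAll();

  Executor* exec = net.SimpleBind(ctx, gpu_args, std::map<std::string, NDArray>(),
                                  std::map<std::string, OpReqType>(), gpu_aux);
  exec->Forward(false);  // run a forward pass
  NDArray::WaitAll();
  delete exec;
  return 0;
}
```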
1 change: 1 addition & 0 deletions cpp-package/include/mxnet-cpp/MxNetCpp.h
@@ -39,5 +39,6 @@
#include "mxnet-cpp/io.hpp"
#include "mxnet-cpp/metric.h"
#include "mxnet-cpp/initializer.h"
#include "mxnet-cpp/contrib.h"

#endif // MXNET_CPP_MXNETCPP_H_
114 changes: 114 additions & 0 deletions cpp-package/include/mxnet-cpp/contrib.h
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file contrib.h
* \brief utility function to enable some contrib features
* \author Haohuan Wang
*/
#ifndef MXNET_CPP_CONTRIB_H_
#define MXNET_CPP_CONTRIB_H_

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "mxnet-cpp/symbol.h"

namespace mxnet {
namespace cpp {
namespace details {

/*!
* split a string with the given delimiter
* @param str string to be parsed
* @param delimiter delimiter
* @return list of substrings split on the delimiter
*/
inline std::vector<std::string> split(const std::string& str, const std::string& delimiter) {
std::vector<std::string> splitted;
size_t last = 0;
size_t next = 0;
while ((next = str.find(delimiter, last)) != std::string::npos) {
splitted.push_back(str.substr(last, next - last));
last = next + delimiter.length();
}
splitted.push_back(str.substr(last));
return splitted;
}

} // namespace details

namespace contrib {

// needs to stay in sync with
// https://github.com/apache/incubator-mxnet/blob/1c874cfc807cee755c38f6486e8e0f4d94416cd8/src/operator/subgraph/tensorrt/tensorrt-inl.h#L190
const static std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER = "subgraph_params_names";
// needs to stay in sync with
// https://github.com/apache/incubator-mxnet/blob/master/src/operator/subgraph/tensorrt/tensorrt.cc#L244
const static std::string TENSORRT_SUBGRAPH_PARAM_PREFIX = "subgraph_param_";
/*!
* this mimics https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/contrib/tensorrt.py#L37
* @param symbol symbol that has already been passed through the subgraph api
* @param argParams original arg params; params needed by TensorRT are removed after calling this function
* @param auxParams original aux params; params needed by TensorRT are removed after calling this function
*/
inline void InitTensorRTParams(const mxnet::cpp::Symbol& symbol,
std::map<std::string, mxnet::cpp::NDArray>& argParams,
std::map<std::string, mxnet::cpp::NDArray>& auxParams) {
mxnet::cpp::Symbol internals = symbol.GetInternals();
mx_uint numSymbol = internals.GetNumOutputs();
for (mx_uint i = 0; i < numSymbol; ++i) {
std::map<std::string, std::string> attrs = internals[i].ListAttributes();
if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) {
std::string new_params_names;
std::map<std::string, mxnet::cpp::NDArray> tensorrtParams;
std::vector<std::string> keys = details::split(attrs[TENSORRT_SUBGRAPH_PARAM_IDENTIFIER], ";");
for (const auto& key : keys) {
if (argParams.find(key) != argParams.end()) {
new_params_names += key + ";";
tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = argParams[key];
argParams.erase(key);
} else if (auxParams.find(key) != auxParams.end()) {
new_params_names += key + ";";
tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = auxParams[key];
auxParams.erase(key);
}
}
std::map<std::string, std::string> new_attrs = {};
for (const auto& kv : tensorrtParams) {
// pass the NDArray handle address into the TRT node attributes so the backend can look up the weight
uintptr_t address = reinterpret_cast<uintptr_t>(kv.second.GetHandle());
new_attrs[kv.first] = std::to_string(address);
}
if (!new_attrs.empty()) {
internals[i].SetAttributes(new_attrs);
internals[i].SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER,
new_params_names.substr(0, new_params_names.length() - 1));
}
}
}
}

} // namespace contrib
} // namespace cpp
}  // namespace mxnet

#endif  // MXNET_CPP_CONTRIB_H_
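As a small, self-contained illustration of the `details::split` helper defined above (the parameter names in the string are made-up examples, not values emitted by the TensorRT backend):

```cpp
// Tiny illustration of details::split; the strings are arbitrary examples.
#include <cassert>
#include <string>
#include <vector>
#include "mxnet-cpp/contrib.h"

int main() {
  std::vector<std::string> parts =
      mxnet::cpp::details::split("conv0_weight;bn0_gamma;bn0_beta", ";");
  assert(parts.size() == 3);
  assert(parts[0] == "conv0_weight");
  assert(parts[2] == "bn0_beta");
  return 0;
}
```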
17 changes: 17 additions & 0 deletions cpp-package/include/mxnet-cpp/symbol.h
@@ -178,6 +178,23 @@ class Symbol {
std::vector<std::string> ListOutputs() const;
/*! \return get the descriptions of auxiliary data for this symbol */
std::vector<std::string> ListAuxiliaryStates() const;
/*! \return get all attributes for this symbol */
std::map<std::string, std::string> ListAttributes() const;
/*!
* \brief set a key-value attribute on the symbol
* @param key string representing the key of the attribute
* @param value string representing the value of the attribute
*/
void SetAttribute(const std::string& key, const std::string& value);
/*!
* \brief set a series of key-value attributes on the symbol
* @param attrs string:string map representing the key-value attributes
*/
void SetAttributes(const std::map<std::string, std::string>& attrs);
/*! \return get number of outputs for this symbol */
mx_uint GetNumOutputs() const;
/*! \return get the new symbol through subgraph API for this symbol */
mxnet::cpp::Symbol GetBackendSymbol(const std::string& backendName) const;
/*! \return get the name of the symbol */
std::string GetName() const;
/*!
35 changes: 35 additions & 0 deletions cpp-package/include/mxnet-cpp/symbol.hpp
@@ -172,6 +172,41 @@ inline std::vector<std::string> Symbol::ListAuxiliaryStates() const {
return ret;
}

inline std::map<std::string, std::string> Symbol::ListAttributes() const {
mx_uint size;
const char** pairs;
CHECK_EQ(MXSymbolListAttrShallow(GetHandle(), &size, &pairs), 0);
std::map<std::string, std::string> attributes;
for (mx_uint i = 0; i < size; ++i) {
// pairs is 2 * size with key, value pairs according to
// https://github.com/apache/incubator-mxnet/blob/master/include/mxnet/c_api.h#L1428
attributes[pairs[2 * i]] = pairs[2 * i + 1];
}
return attributes;
}

inline void Symbol::SetAttribute(const std::string &key, const std::string &value) {
CHECK_EQ(MXSymbolSetAttr(GetHandle(), key.c_str(), value.c_str()), 0);
}

inline void Symbol::SetAttributes(const std::map<std::string, std::string> &attrs) {
for (const auto& kv : attrs) {
SetAttribute(kv.first, kv.second);
}
}

inline mx_uint Symbol::GetNumOutputs() const {
mx_uint numOutputs;
CHECK_EQ(MXSymbolGetNumOutputs(GetHandle(), &numOutputs), 0);
return numOutputs;
}

inline mxnet::cpp::Symbol Symbol::GetBackendSymbol(const std::string &backendName) const {
SymbolHandle symbolHandle;
CHECK_EQ(MXGenBackendSubgraph(GetHandle(), backendName.c_str(), &symbolHandle), 0);
return mxnet::cpp::Symbol(symbolHandle);
}

inline std::string Symbol::GetName() const {
int success;
const char* out_name;
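For reference, a brief sketch of how the new `Symbol` methods introduced by this commit fit together; `net` stands for any previously loaded symbol, and the attribute keys and values are illustrative only.

```cpp
// Brief sketch exercising the Symbol methods added in this commit.
// "net" is any previously loaded symbol; the attribute keys/values are examples.
#include <iostream>
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

void InspectBackendSymbol(const Symbol& net) {
  // Partition the graph for a backend ("TensorRT" is the one this commit targets).
  Symbol trt_net = net.GetBackendSymbol("TensorRT");

  // Walk every internal node and print its attributes.
  Symbol internals = trt_net.GetInternals();
  for (mx_uint i = 0; i < internals.GetNumOutputs(); ++i) {
    for (const auto& kv : internals[i].ListAttributes()) {
      std::cout << internals[i].GetName() << ": "
                << kv.first << " = " << kv.second << std::endl;
    }
  }

  // Attributes can be written back one at a time or in bulk.
  trt_net.SetAttribute("example_key", "example_value");
  trt_net.SetAttributes({{"k1", "v1"}, {"k2", "v2"}});
}
```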
