This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-703] Fix incorrect predictions, update onnx-tensorrt
Updates the IR used to pass subgraphs to ONNX: IR version 3, opset 8.
Fixes a number of bugs including crashes.
Adds support for TensorRT 5.
KellenSunderland committed Nov 18, 2018
1 parent 64657c2 commit 6150f4a
Showing 10 changed files with 75 additions and 75 deletions.
Makefile (1 addition, 2 deletions)
@@ -98,10 +98,9 @@ ifeq ($(ENABLE_TESTCOVERAGE), 1)
endif

ifeq ($(USE_TENSORRT), 1)
CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -I$(TPARTYDIR)/onnx-tensorrt/third_party/onnx/ -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
endif
# -L/usr/local/lib

ifeq ($(DEBUG), 1)
NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt (4 additions, 1 deletion)
@@ -18,7 +18,10 @@
#
# Dockerfile to run MXNet on Ubuntu 16.04 for CPU

FROM nvidia/cuda:9.0-cudnn7-devel
FROM nvidia/cuda:10.0-cudnn7-devel

# Avoid interactive package installers.
ENV DEBIAN_FRONTEND noninteractive

WORKDIR /work/deps

ci/docker/install/tensorrt.sh (2 additions, 2 deletions)
@@ -26,7 +26,7 @@ pip3 install gluoncv==0.2.0
pushd .
cd ..
apt-get update
apt-get install -y automake libtool
apt-get install -y automake libtool zip
git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
cd protobuf
./autogen.sh
@@ -41,7 +41,7 @@ popd

# Install TensorRT
echo "TensorRT build enabled. Installing TensorRT."
wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb
wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0_1-1_amd64.deb
dpkg -i tensorrt.deb
apt-get update
apt-get install -y --allow-downgrades libnvinfer-dev
src/executor/onnx_to_tensorrt.cc (1 addition, 1 deletion)
@@ -28,7 +28,7 @@

#include "./onnx_to_tensorrt.h"

#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>

#include <NvInfer.h>
#include <google/protobuf/io/coded_stream.h>
src/executor/tensorrt_pass.cc (1 addition, 1 deletion)
@@ -31,7 +31,7 @@
#include <mxnet/op_attr_types.h>
#include <mxnet/operator.h>
#include <nnvm/graph_attr_types.h>
#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>

#include "../operator/contrib/nnvm_to_onnx-inl.h"
#include "./exec_pass.h"
src/executor/trt_graph_executor.cc (1 addition, 1 deletion)
@@ -21,7 +21,7 @@

#include "trt_graph_executor.h"

#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>
#include <NvInfer.h>
#include "./onnx_to_tensorrt.h"
#include "../operator/contrib/tensorrt-inl.h"
src/operator/contrib/nnvm_to_onnx-inl.h (1 addition, 1 deletion)
@@ -38,7 +38,7 @@
#include <nnvm/pass_functions.h>

#include <NvInfer.h>
#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>

#include <algorithm>
#include <iostream>
src/operator/contrib/nnvm_to_onnx.cc (62 additions, 64 deletions)
@@ -60,15 +60,21 @@ namespace nnvm_to_onnx {
op::TRTParam ConvertNnvmGraphToOnnx(
const nnvm::Graph& g,
std::unordered_map<std::string, NDArray>* const shared_buffer) {
op::TRTParam trt_param;
op::tensorrt::NameToIdx_t trt_input_map;
op::tensorrt::InferenceMap_t trt_output_map;

static std::atomic_ulong subgraph_count = { 0 };
op::TRTParam trt_param;
op::tensorrt::NameToIdx_t trt_input_map;
op::tensorrt::InferenceMap_t trt_output_map;
const nnvm::IndexedGraph& ig = g.indexed_graph();

const auto& storage_types = g.GetAttr<StorageTypeVector>("storage_type");
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");

// TODO(kellens): At the moment this check always passes no matter the weight dtypes used in your
// graph. We should first iterate over datatypes by name and ensure they're valid types
// (fp16 or fp32) and that they're uniform. Then ensure later conversions set tensor types
// correctly in ONNX.
for (auto& e : storage_types) {
if (e != mshadow::kFloat32) {
LOG(FATAL) << "ONNX converter does not support types other than float32 "
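The TODO above describes a stricter per-tensor check than the storage-type loop that follows it. One way that validation could look, as a hypothetical fragment that reuses the surrounding file's includes (nnvm, mshadow, dmlc logging); the function name is invented for illustration and is not part of this commit:

    // Hypothetical sketch of the dtype validation the TODO describes: accept only
    // fp32/fp16 tensors and require a single uniform dtype across the subgraph.
    void CheckUniformFloatDtypes(const nnvm::Graph& g) {
      const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
      int uniform_dtype = -1;
      for (const int dtype : dtypes) {
        if (dtype != mshadow::kFloat32 && dtype != mshadow::kFloat16) {
          LOG(FATAL) << "TensorRT subgraphs support only fp32 and fp16 tensors.";
        }
        if (uniform_dtype == -1) {
          uniform_dtype = dtype;
        } else if (dtype != uniform_dtype) {
          LOG(FATAL) << "TensorRT subgraphs require a uniform tensor dtype.";
        }
      }
    }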
@@ -78,8 +84,13 @@ op::TRTParam ConvertNnvmGraphToOnnx(

ModelProto model_proto;
// Need to determine IR versions and features to support
model_proto.set_ir_version(static_cast<int64>(2));
auto opset_proto = model_proto.add_opset_import();
opset_proto->set_version(static_cast<int64>(8));
model_proto.set_ir_version(static_cast<int64>(3));

GraphProto* graph_proto = model_proto.mutable_graph();
auto subgraph_name_id = subgraph_count.fetch_add(1);
graph_proto->set_name("MXNetTRTSubgraph" + std::to_string(subgraph_name_id));

std::unordered_map<std::string, TShape> placeholder_shapes =
GetPlaceholderShapes(shape_inputs, ig);
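For reference, the metadata change above amounts to the following standalone sketch (not the converter code itself; it assumes the default onnx protobuf namespace). ONNX models at IR version 3 are expected to declare at least one opset import, and the empty domain string selects the default ONNX operator set:

    #include <onnx/onnx_pb.h>

    // Sketch: an IR-v3 ModelProto carries an explicit opset_import; the default
    // ONNX domain ("") is pinned to opset 8, matching the change above.
    onnx::ModelProto MakeTrtModelSkeleton() {
      onnx::ModelProto model;
      model.set_ir_version(3);
      onnx::OperatorSetIdProto* opset = model.add_opset_import();
      opset->set_domain("");
      opset->set_version(8);
      return model;
    }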
@@ -174,6 +185,20 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
// const bool no_bias = conv_param.no_bias;
const dmlc::optional<int> layout = conv_param.layout;

// dilations
AttributeProto* const dilations = node_proto->add_attribute();
dilations->set_name("dilations");
dilations->set_type(AttributeProto::INTS);
for (const dim_t kval : dilate) {
dilations->add_ints(static_cast<int64>(kval));
}

// group
AttributeProto* const group = node_proto->add_attribute();
group->set_name("group");
group->set_type(AttributeProto::INT);
group->set_i(static_cast<int64>(num_group));

// kernel shape
AttributeProto* const kernel_shape = node_proto->add_attribute();
kernel_shape->set_name("kernel_shape");
@@ -193,27 +218,13 @@
pads->add_ints(static_cast<int64>(kval));
}

// dilations
AttributeProto* const dilations = node_proto->add_attribute();
dilations->set_name("dilations");
dilations->set_type(AttributeProto::INTS);
for (const dim_t kval : dilate) {
dilations->add_ints(static_cast<int64>(kval));
}

// strides
AttributeProto* const strides = node_proto->add_attribute();
strides->set_name("strides");
strides->set_type(AttributeProto::INTS);
for (const dim_t kval : stride) {
strides->add_ints(static_cast<int64>(kval));
}

// group
AttributeProto* const group = node_proto->add_attribute();
group->set_name("group");
group->set_type(AttributeProto::INT);
group->set_i(static_cast<int64>(num_group));
} // end ConvertConvolution

void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
@@ -248,8 +259,12 @@ void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
AttributeProto* const pads = node_proto->add_attribute();
pads->set_name("pads");
pads->set_type(AttributeProto::INTS);
for (int kval : pad) {
pads->add_ints(static_cast<int64>(kval));

// Convert from MXNet symmetric pads to ONNX non-symmetric by running through padding twice.
for (int i = 0; i < 2; i++) {
for (dim_t kval : pad) {
pads->add_ints(static_cast<int64>(kval));
}
}

// strides
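The padding conversion a few lines above is easiest to see with a concrete, hypothetical example: a 2-D pooling with MXNet pad = (1, 2) yields the ONNX pads attribute [1, 2, 1, 2], i.e. the per-axis values repeated once for the begin sides and once for the end sides:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Hypothetical 2-D pooling: MXNet stores one symmetric pad per spatial axis...
      const std::vector<int64_t> mxnet_pad = {1, 2};

      // ...while ONNX "pads" lists all begin pads then all end pads, so the
      // symmetric values are emitted twice, as the loop in ConvertPooling does.
      std::vector<int64_t> onnx_pads;
      for (int side = 0; side < 2; ++side) {
        for (const int64_t p : mxnet_pad) {
          onnx_pads.push_back(p);
        }
      }

      assert((onnx_pads == std::vector<int64_t>{1, 2, 1, 2}));
      return 0;
    }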
@@ -313,11 +328,6 @@ void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs,
beta->set_type(AttributeProto::FLOAT);
beta->set_f(1.0f);

AttributeProto* const broadcast = node_proto->add_attribute();
broadcast->set_name("broadcast");
broadcast->set_type(AttributeProto::INT);
broadcast->set_i(1);

AttributeProto* const transA = node_proto->add_attribute();
transA->set_name("transA");
transA->set_type(AttributeProto::INT);
@@ -369,11 +379,6 @@ void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
epsilon->set_type(AttributeProto::FLOAT);
epsilon->set_f(static_cast<float>(param.eps));

AttributeProto* const is_test = node_proto->add_attribute();
is_test->set_name("is_test");
is_test->set_type(AttributeProto::INT);
is_test->set_i(1);

AttributeProto* const momentum = node_proto->add_attribute();
momentum->set_name("momentum");
momentum->set_type(AttributeProto::FLOAT);
@@ -382,31 +387,16 @@
AttributeProto* const spatial = node_proto->add_attribute();
spatial->set_name("spatial");
spatial->set_type(AttributeProto::INT);
spatial->set_i(1);

AttributeProto* const consumed = node_proto->add_attribute();
consumed->set_name("consumed_inputs");
consumed->set_type(AttributeProto::INTS);

for (int i = 0; i < 5; i++) {
int val = (i < 3) ? 0 : 1;
consumed->add_ints(static_cast<int64>(val));
}
// MXNet computes mean and variance per feature for batchnorm. Enabling spatial mode
// (default in ONNX3) implies running batchnorm on all spatial features so we need to explicitly
// disable this for MXNet's BatchNorm.
spatial->set_i(0);
}

void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
const nnvm::IndexedGraph& /*ig*/,
const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
node_proto->set_op_type("Add");
AttributeProto* const axis = node_proto->add_attribute();
axis->set_name("axis");
axis->set_type(AttributeProto::INT);
axis->set_i(1);

AttributeProto* const broadcast = node_proto->add_attribute();
broadcast->set_name("broadcast");
broadcast->set_type(AttributeProto::INT);
broadcast->set_i(0); // 1
}
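Dropping the axis and broadcast attributes here (and the broadcast attribute from Gemm in ConvertFullyConnected above) lines up with ONNX opset 7 and later, where elementwise operators and Gemm use numpy-style broadcasting and those attributes are no longer part of the op definitions. A minimal sketch of the resulting node, with hypothetical tensor names and assuming the default onnx protobuf namespace:

    #include <onnx/onnx_pb.h>

    // Sketch: under opset 7+ an elementwise Add carries no attributes; input
    // shapes broadcast numpy-style at runtime.
    onnx::NodeProto MakeAddNode() {
      onnx::NodeProto add_node;
      add_node.set_op_type("Add");
      add_node.add_input("lhs");
      add_node.add_input("rhs");
      add_node.add_output("sum");
      return add_node;
    }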

std::unordered_map<std::string, TShape> GetPlaceholderShapes(
@@ -459,32 +449,40 @@ void ConvertPlaceholder(
void ConvertConstant(
GraphProto* const graph_proto, const std::string& node_name,
std::unordered_map<std::string, NDArray>* const shared_buffer) {
NodeProto* const node_proto = graph_proto->add_node();
node_proto->set_name(node_name);
node_proto->add_output(node_name);
node_proto->set_op_type("Constant");
TensorProto* const initializer_proto = graph_proto->add_initializer();

// Create initializer for constants
initializer_proto->set_name(node_name);
// TODO(kellens): convert to fp16 if needed.
initializer_proto->set_data_type(TensorProto_DataType_FLOAT);

const NDArray nd = shared_buffer->find(node_name)->second;
const TBlob& blob = nd.data();
const TShape shape = blob.shape_;
const int32_t size = shape.Size();

for (auto& dim : shape) {
initializer_proto->add_dims(static_cast<int64>(dim));
}

auto size = shape.Size();
// TODO(kellens): Note hard coded float32 size assumed.
std::shared_ptr<float> shared_data_ptr(new float[size]);
float* const data_ptr = shared_data_ptr.get();
nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);

AttributeProto* const tensor_attr = node_proto->add_attribute();
tensor_attr->set_name("value");
tensor_attr->set_type(AttributeProto::TENSOR);

TensorProto* const tensor_proto = tensor_attr->mutable_t();
tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
for (auto& dim : shape) {
tensor_proto->add_dims(static_cast<int64>(dim));
for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
initializer_proto->add_float_data(data_ptr[blob_idx]);
}

for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
tensor_proto->add_float_data(data_ptr[blob_idx]);
// Create inputs for constants.
ValueInfoProto* const input_proto = graph_proto->add_input();
input_proto->set_name(node_name);

// TODO(kellens): (fp16 support)
input_proto->mutable_type()->mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
for (auto& dim : shape) {
auto new_dim = input_proto->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim();
new_dim->set_dim_value(static_cast<int64>(dim));
}
}
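A note on the design choice above: constants are no longer emitted as Constant nodes but as graph initializers, and under ONNX IR version 3 each initializer is expected to have a matching, identically named entry in graph.input, which is why the same tensor is registered twice. A standalone sketch with a hypothetical weight name and shape, assuming the default onnx protobuf namespace:

    #include <onnx/onnx_pb.h>

    // Sketch: register a weight both as an initializer and as a typed graph
    // input with the same name, as IR-v3 importers expect. Name/dims invented.
    void AddConstWeight(onnx::GraphProto* graph) {
      onnx::TensorProto* init = graph->add_initializer();
      init->set_name("conv0_weight");
      init->set_data_type(onnx::TensorProto_DataType_FLOAT);
      init->add_dims(64);
      init->add_dims(3);

      onnx::ValueInfoProto* input = graph->add_input();
      input->set_name("conv0_weight");
      auto* tensor_type = input->mutable_type()->mutable_tensor_type();
      tensor_type->set_elem_type(onnx::TensorProto_DataType_FLOAT);
      tensor_type->mutable_shape()->add_dim()->set_dim_value(64);
      tensor_type->mutable_shape()->add_dim()->set_dim_value(3);
    }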

src/operator/contrib/tensorrt-inl.h (1 addition, 1 deletion)
@@ -38,7 +38,7 @@
#include <nnvm/pass_functions.h>

#include <NvInfer.h>
#include <onnx/onnx.pb.h>
#include <onnx/onnx_pb.h>

#include <algorithm>
#include <iostream>
